diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 31dfac6f53a2..e616c3f3f53d 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -46,7 +46,7 @@ jobs: - name: Load benchmark data from cache id: cache-data - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: py-polars/tests/benchmark/G1_1e7_1e2_5_0.csv key: benchmark-data @@ -66,7 +66,7 @@ jobs: - name: Save benchmark data in cache if: github.ref_name == 'main' - uses: actions/cache/save@v3 + uses: actions/cache/save@v4 with: path: py-polars/tests/benchmark/G1_1e7_1e2_5_0.csv key: ${{ steps.cache-data.outputs.cache-primary-key }} @@ -87,9 +87,7 @@ jobs: env: RUSTFLAGS: -C embed-bitcode -D warnings working-directory: py-polars - run: | - source activate - maturin develop --release -- -C codegen-units=8 -C lto=thin -C target-cpu=native + run: maturin develop --release -- -C codegen-units=8 -C lto=thin -C target-cpu=native - name: Run H2O AI database benchmark - on strings working-directory: py-polars/tests/benchmark diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml new file mode 100644 index 000000000000..7e8c2353c5e5 --- /dev/null +++ b/.github/workflows/codecov.yml @@ -0,0 +1,107 @@ +name: Code coverage + +on: + pull_request: + paths: + - '**.rs' + - '**.py' + - .github/workflows/codecov.yml + push: + branches: + - main + paths: + - '**.rs' + - '**.py' + - .github/workflows/codecov.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +defaults: + run: + working-directory: py-polars + shell: bash + +jobs: + coverage: + name: Code Coverage + runs-on: macos-latest + env: + RUSTFLAGS: '-C instrument-coverage --cfg=coverage --cfg=coverage_nightly --cfg=trybuild_no_target' + RUST_BACKTRACE: 1 + LLVM_PROFILE_FILE: '/Users/runner/work/polars/polars/target/polars-%p-%3m.profraw' + CARGO_LLVM_COV: 1 + CARGO_LLVM_COV_SHOW_ENV: 1 + CARGO_LLVM_COV_TARGET_DIR: '/Users/runner/work/polars/polars/target' + # Workaround for issue compiling libz-ng-sys crate on macOS (resource temporarily unavailable) + # See: https://github.com/pola-rs/polars/pull/14715 + CMAKE_BUILD_PARALLEL_LEVEL: 10 + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Create virtual environment + run: | + python -m venv .venv + echo "$GITHUB_WORKSPACE/py-polars/.venv/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: pip install -r requirements-dev.txt + + - name: Set up Rust + run: rustup component add llvm-tools-preview + + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + + - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Prepare coverage + run: cargo llvm-cov clean --workspace + + - name: Run tests + run: > + cargo test --all-features + -p polars-arrow + -p polars-core + -p polars-io + -p polars-lazy + -p polars-ops + -p polars-plan + -p polars-row + -p polars-sql + -p polars-time + -p polars-utils + + - name: Run Rust integration tests + run: cargo test --all-features -p polars --test it + + - name: Install Polars + run: maturin develop + + - name: Run Python tests + run: pytest --cov -n auto --dist loadgroup -m "not benchmark and not docs" --cov-report xml:main.xml + continue-on-error: true + + - name: Run Python tests - async reader + env: + POLARS_FORCE_ASYNC: 1 + run: pytest --cov -m "not benchmark and not docs" tests/unit/io/ --cov-report xml:async.xml + continue-on-error: true + + - name: Report coverage + run: cargo llvm-cov report --lcov --output-path coverage.lcov + + - name: Upload coverage information + uses: codecov/codecov-action@v4 + with: + files: py-polars/coverage.lcov,py-polars/main.xml,py-polars/async.xml + name: macos + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/docs-global.yml b/.github/workflows/docs-global.yml index 9e66576158e6..cddd3fcbced3 100644 --- a/.github/workflows/docs-global.yml +++ b/.github/workflows/docs-global.yml @@ -82,12 +82,10 @@ jobs: - name: Install Polars working-directory: py-polars - run: | - source activate - maturin develop + run: maturin develop - name: Set up Graphviz - uses: ts-graphviz/setup-graphviz@v1 + uses: ts-graphviz/setup-graphviz@v2 - name: Build documentation env: diff --git a/.github/workflows/lint-global.yml b/.github/workflows/lint-global.yml index fb2ee2c8e4f2..d3383dc164fc 100644 --- a/.github/workflows/lint-global.yml +++ b/.github/workflows/lint-global.yml @@ -15,4 +15,4 @@ jobs: - name: Lint Markdown and TOML uses: dprint/check@v2.2 - name: Spell Check with Typos - uses: crate-ci/typos@v1.16.21 + uses: crate-ci/typos@v1.17.2 diff --git a/.github/workflows/lint-rust.yml b/.github/workflows/lint-rust.yml index 9ac00ca53886..cb974eb8bd9d 100644 --- a/.github/workflows/lint-rust.yml +++ b/.github/workflows/lint-rust.yml @@ -44,7 +44,7 @@ jobs: save-if: ${{ github.ref_name == 'main' }} - name: Run cargo clippy with all features enabled - run: cargo clippy --workspace --all-targets --all-features --locked -- -D warnings + run: cargo clippy --workspace --all-targets --all-features --locked -- -D warnings -D clippy::dbg_macro # Default feature set should compile on the stable toolchain clippy-stable: @@ -64,7 +64,7 @@ jobs: save-if: ${{ github.ref_name == 'main' }} - name: Run cargo clippy - run: cargo clippy --all-targets --locked -- -D warnings + run: cargo clippy --all-targets --locked -- -D warnings -D clippy::dbg_macro rustfmt: if: github.ref_name != 'main' diff --git a/.github/workflows/pr-labeler.yml b/.github/workflows/pr-labeler.yml index 13b82c26e61e..7c9be45095fe 100644 --- a/.github/workflows/pr-labeler.yml +++ b/.github/workflows/pr-labeler.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Label pull request - uses: release-drafter/release-drafter@v5 + uses: release-drafter/release-drafter@v6 with: disable-releaser: true env: diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yml index 03f1aca65d07..84229ef07920 100644 --- a/.github/workflows/release-drafter.yml +++ b/.github/workflows/release-drafter.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Draft Rust release - uses: release-drafter/release-drafter@v5 + uses: release-drafter/release-drafter@v6 with: config-name: release-drafter-rust.yml commitish: ${{ inputs.sha || github.sha }} @@ -29,7 +29,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Draft Python release - uses: release-drafter/release-drafter@v5 + uses: release-drafter/release-drafter@v6 with: config-name: release-drafter-python.yml commitish: ${{ inputs.sha || github.sha }} diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index a686835e301d..87c5803c5a28 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -108,6 +108,12 @@ jobs: with: swap-size-gb: 10 + # Workaround for issue compiling libz-ng-sys crate on macOS (resource temporarily unavailable) + # See: https://github.com/pola-rs/polars/pull/14715 + - name: Set cmake build threads for MacOS + if: matrix.os == 'macos-latest' + run: echo "CMAKE_BUILD_PARALLEL_LEVEL=10" >> $GITHUB_ENV + - name: Set up Python uses: actions/setup-python@v5 with: @@ -235,7 +241,7 @@ jobs: - name: Create GitHub release id: github-release - uses: release-drafter/release-drafter@v5 + uses: release-drafter/release-drafter@v6 with: config-name: release-drafter-python.yml name: Python Polars ${{ steps.version.outputs.version }} @@ -263,7 +269,7 @@ jobs: - name: Trigger other workflows related to the release if: inputs.dry-run == false && steps.version.outputs.is_prerelease == 'false' - uses: peter-evans/repository-dispatch@v2 + uses: peter-evans/repository-dispatch@v3 with: event-type: python-release client-payload: > diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 07f7f2eef8bd..f6ac631177c8 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -3,6 +3,7 @@ name: Test Python on: pull_request: paths: + - Cargo.lock - py-polars/** - docs/src/python/** - crates/** @@ -11,6 +12,7 @@ on: branches: - main paths: + - Cargo.lock - crates/** - docs/src/python/** - py-polars/** @@ -49,7 +51,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Set up Graphviz - uses: ts-graphviz/setup-graphviz@v1 + uses: ts-graphviz/setup-graphviz@v2 - name: Create virtual environment env: @@ -71,16 +73,14 @@ jobs: save-if: ${{ github.ref_name == 'main' }} - name: Install Polars - run: | - source activate - maturin develop + run: maturin develop - name: Run doctests if: github.ref_name != 'main' && matrix.python-version == '3.12' && matrix.os == 'ubuntu-latest' run: | python tests/docs/run_doctest.py pytest tests/docs/test_user_guide.py -m docs - + - name: Run tests and report coverage if: github.ref_name != 'main' env: @@ -92,7 +92,9 @@ jobs: - name: Run tests async reader tests if: github.ref_name != 'main' && matrix.os != 'windows-latest' - run: POLARS_FORCE_ASYNC=1 pytest -m "not benchmark and not docs" tests/unit/io/ + env: + POLARS_FORCE_ASYNC: 1 + run: pytest -m "not benchmark and not docs" tests/unit/io/ - name: Check import without optional dependencies if: github.ref_name != 'main' && matrix.python-version == '3.12' && matrix.os == 'ubuntu-latest' diff --git a/.gitignore b/.gitignore index 8a306d27c861..525e4a5301e5 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ __pycache__/ .coverage # Rust +.cargo/ target/ # Project diff --git a/Cargo.lock b/Cargo.lock index 835ef1a03cbc..a8e0514c0e16 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -25,9 +25,9 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "ahash" -version = "0.8.7" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +checksum = "d713b3834d76b85304d4d525563c1276e2e30dc97cc67bfb4585a4a29fc2c89f" dependencies = [ "cfg-if", "const-random", @@ -90,15 +90,15 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.4" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "anyhow" -version = "1.0.79" +version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" +checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" [[package]] name = "apache-avro" @@ -142,48 +142,48 @@ checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" [[package]] name = "arrow-array" -version = "49.0.0" +version = "50.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bda9acea48b25123c08340f3a8ac361aa0f74469bb36f5ee9acf923fce23e9d" +checksum = "d390feeb7f21b78ec997a4081a025baef1e2e0d6069e181939b61864c9779609" dependencies = [ "ahash", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", - "half 2.3.1", + "half", "hashbrown 0.14.3", "num", ] [[package]] name = "arrow-buffer" -version = "49.0.0" +version = "50.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01a0fc21915b00fc6c2667b069c1b64bdd920982f426079bc4a7cab86822886c" +checksum = "69615b061701bcdffbc62756bc7e85c827d5290b472b580c972ebbbf690f5aa4" dependencies = [ "bytes", - "half 2.3.1", + "half", "num", ] [[package]] name = "arrow-data" -version = "49.0.0" +version = "50.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "907fafe280a3874474678c1858b9ca4cb7fd83fb8034ff5b6d6376205a08c634" +checksum = "67d644b91a162f3ad3135ce1184d0a31c28b816a581e08f29e8e9277a574c64e" dependencies = [ "arrow-buffer", "arrow-schema", - "half 2.3.1", + "half", "num", ] [[package]] name = "arrow-schema" -version = "49.0.0" +version = "50.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09e28a5e781bf1b0f981333684ad13f5901f4cd2f20589eab7cf1797da8fc167" +checksum = "0ff3e9c01f7cd169379d269f926892d0e622a704960350d09d331be3ec9e0029" [[package]] name = "arrow2" @@ -224,7 +224,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] @@ -235,7 +235,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] @@ -277,12 +277,11 @@ dependencies = [ [[package]] name = "aws-config" -version = "1.1.1" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11382bd8ac4c6c182a9775990935f96c916a865f1414486595f18eb8cfa9d90b" +checksum = "3182c19847238b50b62ae0383a6dbfc14514e552eb5e307e1ea83ccf5840b8a6" dependencies = [ "aws-credential-types", - "aws-http", "aws-runtime", "aws-sdk-sso", "aws-sdk-ssooidc", @@ -297,7 +296,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http", + "http 0.2.11", "hyper", "ring", "time", @@ -308,9 +307,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.1.1" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70a1629320d319dc715c6189b172349186557e209d2a7b893ff3d14efd33a47c" +checksum = "e5635d8707f265c773282a22abe1ecd4fbe96a8eb2f0f14c0796f8016f11a41a" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -318,30 +317,13 @@ dependencies = [ "zeroize", ] -[[package]] -name = "aws-http" -version = "0.60.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30e4199d5d62ab09be6a64650c06cc5c4aa45806fed4c74bc4a5c8eaf039a6fa" -dependencies = [ - "aws-smithy-runtime-api", - "aws-smithy-types", - "aws-types", - "bytes", - "http", - "http-body", - "pin-project-lite", - "tracing", -] - [[package]] name = "aws-runtime" -version = "1.1.1" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87116d357c905b53f1828d15366363fd27b330a0393cbef349e653f686d36bad" +checksum = "6f82b9ae2adfd9d6582440d0eeb394c07f74d21b4c0cc72bdb73735c9e1a9c0e" dependencies = [ "aws-credential-types", - "aws-http", "aws-sigv4", "aws-smithy-async", "aws-smithy-eventstream", @@ -349,21 +331,23 @@ dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", + "bytes", "fastrand", - "http", + "http 0.2.11", + "http-body", "percent-encoding", + "pin-project-lite", "tracing", "uuid", ] [[package]] name = "aws-sdk-s3" -version = "1.11.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21392b29994de019a7059af5eab144ea49d572dd52863d8e10537267f59f998c" +checksum = "5076637347e7d0218e61facae853110682ae58efabd2f4e2a9e530c203d5fa7b" dependencies = [ "aws-credential-types", - "aws-http", "aws-runtime", "aws-sigv4", "aws-smithy-async", @@ -377,7 +361,7 @@ dependencies = [ "aws-smithy-xml", "aws-types", "bytes", - "http", + "http 0.2.11", "http-body", "once_cell", "percent-encoding", @@ -388,12 +372,11 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.9.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da9d9a8ac4cdb8df39f9777fd41e15a9ae0d0b622b00909ae0322b4d2f9e6ac8" +checksum = "ca7e8097448832fcd22faf6bb227e97d76b40e354509d1307653a885811c7151" dependencies = [ "aws-credential-types", - "aws-http", "aws-runtime", "aws-smithy-async", "aws-smithy-http", @@ -403,7 +386,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "http", + "http 0.2.11", "once_cell", "regex-lite", "tracing", @@ -411,12 +394,11 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.9.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56ba4a42aa91acecd5ca43b330b5c8eb7f8808d720b6a6f796a35faa302fc73d" +checksum = "a75073590e23d63044606771afae309fada8eb10ded54a1ce4598347221d3fef" dependencies = [ "aws-credential-types", - "aws-http", "aws-runtime", "aws-smithy-async", "aws-smithy-http", @@ -426,7 +408,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "http", + "http 0.2.11", "once_cell", "regex-lite", "tracing", @@ -434,12 +416,11 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.9.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e3c7c3dcec7cccd24a13953eedf0f2964c2d728d22112744274cf0098ad2e35" +checksum = "650e4aaae41547151dea4d8142f7ffcc8ab8ba76d5dccc8933936ef2102c3356" dependencies = [ "aws-credential-types", - "aws-http", "aws-runtime", "aws-smithy-async", "aws-smithy-http", @@ -450,7 +431,7 @@ dependencies = [ "aws-smithy-types", "aws-smithy-xml", "aws-types", - "http", + "http 0.2.11", "once_cell", "regex-lite", "tracing", @@ -458,9 +439,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.1.1" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d222297ca90209dc62245f0a490355795f29de362eb5c19caea4f7f55fe69078" +checksum = "404c64a104188ac70dd1684718765cb5559795458e446480e41984e68e57d888" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -472,7 +453,8 @@ dependencies = [ "form_urlencoded", "hex", "hmac", - "http", + "http 0.2.11", + "http 1.0.0", "once_cell", "p256", "percent-encoding", @@ -486,9 +468,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.1.1" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9f65000917e3aa94c259d67fe01fa9e4cd456187d026067d642436e6311a81" +checksum = "fcf7f09a27286d84315dfb9346208abb3b0973a692454ae6d0bc8d803fcce3b4" dependencies = [ "futures-util", "pin-project-lite", @@ -497,9 +479,9 @@ dependencies = [ [[package]] name = "aws-smithy-checksums" -version = "0.60.1" +version = "0.60.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c2a63681f82fb85ca58d566534b7dc619c782fee0c61c1aa51e2b560c21cb4f" +checksum = "0fd4b66f2a8e7c84d7e97bda2666273d41d2a2e25302605bcf906b7b2661ae5e" dependencies = [ "aws-smithy-http", "aws-smithy-types", @@ -507,7 +489,7 @@ dependencies = [ "crc32c", "crc32fast", "hex", - "http", + "http 0.2.11", "http-body", "md-5", "pin-project-lite", @@ -518,9 +500,9 @@ dependencies = [ [[package]] name = "aws-smithy-eventstream" -version = "0.60.1" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a85e16fa903c70c49ab3785e5f4ac2ad2171b36e0616f321011fa57962404bb6" +checksum = "e6363078f927f612b970edf9d1903ef5cef9a64d1e8423525ebb1f0a1633c858" dependencies = [ "aws-smithy-types", "bytes", @@ -529,9 +511,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.1" +version = "0.60.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4e816425a6b9caea4929ac97d0cb33674849bd5f0086418abc0d02c63f7a1bf" +checksum = "b6ca214a6a26f1b7ebd63aa8d4f5e2194095643023f9608edf99a58247b9d80d" dependencies = [ "aws-smithy-eventstream", "aws-smithy-runtime-api", @@ -539,7 +521,7 @@ dependencies = [ "bytes", "bytes-utils", "futures-core", - "http", + "http 0.2.11", "http-body", "once_cell", "percent-encoding", @@ -550,18 +532,18 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.60.1" +version = "0.60.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ab3f6d49e08df2f8d05e1bb5b68998e1e67b76054d3c43e7b954becb9a5e9ac" +checksum = "1af80ecf3057fb25fe38d1687e94c4601a7817c6a1e87c1b0635f7ecb644ace5" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-query" -version = "0.60.1" +version = "0.60.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f94a7a3aa509ff9e8b8d80749851d04e5eee0954c43f2e7d6396c4740028737" +checksum = "eb27084f72ea5fc20033efe180618677ff4a2f474b53d84695cfe310a6526cbc" dependencies = [ "aws-smithy-types", "urlencoding", @@ -569,9 +551,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.1.1" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8da5b0a3617390e769576321816112f711c13d7e1114685e022505cf51fe5e48" +checksum = "fbb5fca54a532a36ff927fbd7407a7c8eb9c3b4faf72792ba2965ea2cad8ed55" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -580,7 +562,7 @@ dependencies = [ "bytes", "fastrand", "h2", - "http", + "http 0.2.11", "http-body", "hyper", "hyper-rustls", @@ -594,14 +576,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.1.1" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2404c9eb08bfe9af255945254d9afc69a367b7ee008b8db75c05e3bca485fc65" +checksum = "22389cb6f7cac64f266fb9f137745a9349ced7b47e0d2ba503e9e40ede4f7060" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", - "http", + "http 0.2.11", + "http 1.0.0", "pin-project-lite", "tokio", "tracing", @@ -610,15 +593,15 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.1.1" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2aba8136605d14ac88f57dc3a693a9f8a4eab4a3f52bc03ff13746f0cd704e97" +checksum = "f081da5481210523d44ffd83d9f0740320050054006c719eae0232d411f024d3" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", - "http", + "http 0.2.11", "http-body", "itoa", "num-integer", @@ -633,24 +616,24 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.1" +version = "0.60.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e8f03926587fc881b12b102048bb04305bf7fb8c83e776f0ccc51eaa2378263" +checksum = "0fccd8f595d0ca839f9f2548e66b99514a85f92feb4c01cf2868d93eb4888a42" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.1.1" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e5d5ee29077e0fcd5ddd0c227b521a33aaf02434b7cdba1c55eec5c1f18ac47" +checksum = "8fbb5d48aae496f628e7aa2e41991dd4074f606d9e3ade1ce1059f293d40f9a2" dependencies = [ "aws-credential-types", "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", - "http", + "http 0.2.11", "rustc_version", "tracing", ] @@ -678,9 +661,9 @@ checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" [[package]] name = "base64" -version = "0.21.5" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "base64-simd" @@ -715,9 +698,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.1" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" dependencies = [ "serde", ] @@ -765,15 +748,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.14.0" +version = "3.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" +checksum = "8ea184aa71bb362a1157c896979544cc23974e08fd265f29ea96b59f0b4a555b" [[package]] name = "bytemuck" -version = "1.14.0" +version = "1.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" dependencies = [ "bytemuck_derive", ] @@ -786,7 +769,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] @@ -834,11 +817,10 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.83" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "3286b845d0fccbdd15af433f61c5970e711987036cb468f437ff6badd70f4e24" dependencies = [ - "jobserver", "libc", ] @@ -850,22 +832,22 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.31" +version = "0.4.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" +checksum = "5bc015644b92d5890fab7489e49d21f879d5c990186827d42ec511919404f38b" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", "serde", - "windows-targets 0.48.5", + "windows-targets 0.52.3", ] [[package]] name = "chrono-tz" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91d7b79e99bfaa0d47da0687c43aa3b7381938a62ad3a6498599039321f660b7" +checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" dependencies = [ "chrono", "chrono-tz-build", @@ -885,9 +867,9 @@ dependencies = [ [[package]] name = "ciborium" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" dependencies = [ "ciborium-io", "ciborium-ll", @@ -896,34 +878,34 @@ dependencies = [ [[package]] name = "ciborium-io" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" [[package]] name = "ciborium-ll" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" dependencies = [ "ciborium-io", - "half 1.8.2", + "half", ] [[package]] name = "clap" -version = "4.4.12" +version = "4.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcfab8ba68f3668e89f6ff60f5b205cea56aa7b769451a59f34b8682f51c056d" +checksum = "c918d541ef2913577a0f9566e9ce27cb35b6df072075769e0b26cb5a554520da" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.4.12" +version = "4.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb7fb5e4e979aec3be7791562fcba452f94ad85e954da024396433e0e25a79e9" +checksum = "9f3e7391dad68afb0c2ede1bf619f579a3dc9c2ec67f089baa397123a2f3d1eb" dependencies = [ "anstyle", "clap_lex", @@ -931,9 +913,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" +checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" [[package]] name = "cmake" @@ -1009,9 +991,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" dependencies = [ "libc", ] @@ -1033,18 +1015,18 @@ checksum = "ccaeedb56da03b09f598226e25e80088cb4cd25f316e6e4df7d695f0feeb1403" [[package]] name = "crc32c" -version = "0.6.4" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8f48d60e5b4d2c53d5c2b1d8a58c849a70ae5e5509b08a48d047e3b65714a74" +checksum = "89254598aa9b9fa608de44b3ae54c810f0f06d755e24c50177f1f8f31ff50ce2" dependencies = [ "rustc_version", ] [[package]] name = "crc32fast" -version = "1.3.2" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" dependencies = [ "cfg-if", ] @@ -1087,54 +1069,46 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.10" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82a9b73a36529d9c47029b9fb3a6f0ea3cc916a261195352ba19e770fc1748b2" +checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" dependencies = [ - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-deque" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fca89a0e215bab21874660c67903c5f143333cab1da83d041c7ded6053774751" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" dependencies = [ - "cfg-if", "crossbeam-epoch", "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" -version = "0.9.17" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e3681d554572a651dda4186cd47240627c3d0114d45a95f6ad27f2f22e7548d" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" dependencies = [ - "autocfg", - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-queue" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc6598521bb5a83d491e8c1fe51db7296019d2ca3cb93cc6c2a20369a4d78a2" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" dependencies = [ - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.18" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c" -dependencies = [ - "cfg-if", -] +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" [[package]] name = "crossterm" @@ -1142,7 +1116,7 @@ version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "crossterm_winapi", "libc", "parking_lot", @@ -1258,9 +1232,9 @@ dependencies = [ [[package]] name = "either" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" [[package]] name = "elliptic-curve" @@ -1300,7 +1274,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] @@ -1460,7 +1434,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] @@ -1505,9 +1479,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "js-sys", @@ -1524,11 +1498,11 @@ checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "git2" -version = "0.18.1" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf97ba92db08df386e10c8ede66a2a0369bd277090afd8710e19e38de9ec0cd" +checksum = "1b3ba52851e73b46a4c3df1d89343741112003f0f6f13beb0dfac9e457c3fdcd" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "libc", "libgit2-sys", "log", @@ -1563,7 +1537,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.11", "indexmap", "slab", "tokio", @@ -1571,12 +1545,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "half" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" - [[package]] name = "half" version = "2.3.1" @@ -1632,9 +1600,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.3" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +checksum = "379dada1584ad501b383485dd706b8afb7a70fcbc7f4da7d780638a5a6124a60" [[package]] name = "hex" @@ -1671,6 +1639,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -1678,7 +1657,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.11", "pin-project-lite", ] @@ -1711,7 +1690,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.11", "http-body", "httparse", "httpdate", @@ -1731,7 +1710,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", - "http", + "http 0.2.11", "hyper", "log", "rustls", @@ -1742,9 +1721,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.59" +version = "0.1.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6a67363e2aa4443928ce15e57ebae94fd8949958fd1223c4cfc0cd473ad7539" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -1775,9 +1754,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.1.0" +version = "2.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +checksum = "233cf39063f058ea2caae4091bf4a3ef70a653afbc026f5c4a4135d114e3c177" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -1792,9 +1771,9 @@ checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" [[package]] name = "inventory" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8573b2b1fb643a372c73b23f4da5f888677feef3305146d68a539250a9bccc7" +checksum = "f958d3d68f4167080a18141e10381e7634563984a537f2a49a30fd8e53ac5767" [[package]] name = "ipnet" @@ -1804,12 +1783,12 @@ checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is-terminal" -version = "0.4.10" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" +checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" dependencies = [ "hermit-abi", - "rustix", + "libc", "windows-sys 0.52.0", ] @@ -1824,9 +1803,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" dependencies = [ "either", ] @@ -1863,20 +1842,11 @@ dependencies = [ "libc", ] -[[package]] -name = "jobserver" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" -dependencies = [ - "libc", -] - [[package]] name = "js-sys" -version = "0.3.66" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" +checksum = "406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee" dependencies = [ "wasm-bindgen", ] @@ -1963,9 +1933,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.151" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libflate" @@ -2013,9 +1983,9 @@ dependencies = [ [[package]] name = "libgit2-sys" -version = "0.16.1+1.7.1" +version = "0.16.2+1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2a2bb3680b094add03bb3732ec520ece34da31a8cd2d633d1389d0f0fb60d0c" +checksum = "ee4126d8b4ee5c9d9ea891dd875cfdc1e9d0950437179104b183d7d8a74d24e8" dependencies = [ "cc", "libc", @@ -2051,9 +2021,9 @@ dependencies = [ [[package]] name = "libz-ng-sys" -version = "1.1.12" +version = "1.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dd9f43e75536a46ee0f92b758f6b63846e594e86638c61a9251338a65baea63" +checksum = "c6409efc61b12687963e602df8ecf70e8ddacf95bc6576bcf16e3ac6328083c5" dependencies = [ "cmake", "libc", @@ -2061,9 +2031,9 @@ dependencies = [ [[package]] name = "libz-sys" -version = "1.1.12" +version = "1.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d97137b25e321a73eef1418d1d5d2eda4d77e12813f8e6dead84bc52c5870a7b" +checksum = "037731f5d3aaa87a5675e895b63ddff1a87624bc29f77004ea829809654e48f6" dependencies = [ "cc", "libc", @@ -2073,9 +2043,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "lock_api" @@ -2174,9 +2144,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", ] @@ -2272,28 +2242,33 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" -version = "0.1.45" +version = "0.1.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" dependencies = [ - "autocfg", "num-traits", ] [[package]] name = "num-iter" -version = "0.1.43" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" dependencies = [ "autocfg", "num-integer", @@ -2314,9 +2289,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", "libm", @@ -2369,14 +2344,14 @@ dependencies = [ "futures", "humantime", "hyper", - "itertools 0.12.0", + "itertools 0.12.1", "parking_lot", "percent-encoding", "quick-xml", "rand", "reqwest", "ring", - "rustls-pemfile 2.0.0", + "rustls-pemfile 2.1.0", "serde", "serde_json", "snafu", @@ -2531,9 +2506,9 @@ dependencies = [ [[package]] name = "pkg-config" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" [[package]] name = "planus" @@ -2577,15 +2552,27 @@ name = "polars" version = "0.37.0" dependencies = [ "ahash", + "apache-avro", + "avro-schema", + "either", + "ethnum", + "futures", "getrandom", + "polars-arrow", "polars-core", + "polars-error", "polars-io", "polars-lazy", "polars-ops", + "polars-parquet", "polars-plan", "polars-sql", "polars-time", + "polars-utils", + "proptest", "rand", + "tokio", + "tokio-util", "version_check", ] @@ -2594,7 +2581,6 @@ name = "polars-arrow" version = "0.37.0" dependencies = [ "ahash", - "apache-avro", "arrow-array", "arrow-buffer", "arrow-data", @@ -2663,10 +2649,12 @@ name = "polars-compute" version = "0.37.0" dependencies = [ "bytemuck", + "either", "num-traits", "polars-arrow", "polars-error", "polars-utils", + "strength_reduce", "version_check", ] @@ -2677,7 +2665,7 @@ dependencies = [ "ahash", "arrow-array", "bincode", - "bitflags 2.4.1", + "bitflags 2.4.2", "bytemuck", "chrono", "chrono-tz", @@ -2806,7 +2794,7 @@ name = "polars-lazy" version = "0.37.0" dependencies = [ "ahash", - "bitflags 2.4.1", + "bitflags 2.4.2", "futures", "glob", "once_cell", @@ -2910,6 +2898,7 @@ dependencies = [ "rayon", "smartstring", "tokio", + "uuid", "version_check", ] @@ -3005,6 +2994,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "powerfmt" version = "0.2.0" @@ -3019,9 +3014,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.74" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2de98502f212cfcea8d0bb305bd0f49d7ebdd75b64ba0a68f937d888f4e0d6db" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] @@ -3032,7 +3027,7 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "lazy_static", "num-traits", "rand", @@ -3067,7 +3062,7 @@ dependencies = [ [[package]] name = "py-polars" -version = "0.20.6" +version = "0.20.11" dependencies = [ "ahash", "built", @@ -3078,6 +3073,7 @@ dependencies = [ "libc", "mimalloc", "ndarray", + "num-traits", "numpy", "once_cell", "polars", @@ -3097,9 +3093,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82ad98ce1991c9c70c3464ba4187337b9c45fcbbb060d46dca15f0c075e14e2" +checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" dependencies = [ "cfg-if", "indoc", @@ -3107,6 +3103,7 @@ dependencies = [ "libc", "memoffset", "parking_lot", + "portable-atomic", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -3115,9 +3112,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5503d0b3aee2c7a8dbb389cd87cd9649f675d4c7f60ca33699a3e3859d81a891" +checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" dependencies = [ "once_cell", "target-lexicon", @@ -3131,9 +3128,9 @@ checksum = "be6d574e0f8cab2cdd1eeeb640cbf845c974519fa9e9b62fa9c08ecece0ca5de" [[package]] name = "pyo3-ffi" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18a79e8d80486a00d11c0dcb27cd2aa17c022cc95c677b461f01797226ba8f41" +checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" dependencies = [ "libc", "pyo3-build-config", @@ -3141,26 +3138,27 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f4b0dc7eaa578604fab11c8c7ff8934c71249c61d4def8e272c76ed879f03d4" +checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] name = "pyo3-macros-backend" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816a4f709e29ddab2e3cdfe94600d554c5556cad0ddfeea95c47b580c3247fa4" +checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" dependencies = [ "heck", "proc-macro2", + "pyo3-build-config", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] @@ -3266,9 +3264,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" dependencies = [ "either", "rayon-core", @@ -3276,9 +3274,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -3310,14 +3308,14 @@ checksum = "5fddb4f8d99b0a2ebafc65a87a69a7b9875e4b1ae1f00db265d300ef7f28bccc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] name = "regex" -version = "1.10.2" +version = "1.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" dependencies = [ "aho-corasick", "memchr", @@ -3327,9 +3325,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.3" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" dependencies = [ "aho-corasick", "memchr", @@ -3356,9 +3354,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.23" +version = "0.11.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41" +checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" dependencies = [ "base64", "bytes", @@ -3366,7 +3364,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.11", "http-body", "hyper", "hyper-rustls", @@ -3383,6 +3381,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "sync_wrapper", "system-configuration", "tokio", "tokio-rustls", @@ -3409,16 +3408,17 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.7" +version = "0.17.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", + "cfg-if", "getrandom", "libc", "spin", "untrusted", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -3450,11 +3450,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.28" +version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" +checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "errno", "libc", "linux-raw-sys", @@ -3496,9 +3496,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4" +checksum = "3c333bb734fcdedcea57de1602543590f545f127dc8b533324318fd492c5c70b" dependencies = [ "base64", "rustls-pki-types", @@ -3506,9 +3506,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.1.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e9d979b3ce68192e42760c7810125eb6cf2ea10efae545a156063e61f314e2a" +checksum = "048a63e5b3ac996d78d402940b5fa47973d2d080c6c6fffa1d0f19c4445310b7" [[package]] name = "rustls-webpki" @@ -3528,9 +3528,9 @@ checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" [[package]] name = "ryu" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "same-file" @@ -3651,9 +3651,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" dependencies = [ "serde", ] @@ -3666,29 +3666,29 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.194" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b114498256798c94a0689e1a15fec6005dee8ac1f41de56404b67afc2a4b773" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.194" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3385e45322e8f9931410f01b3031ec534c3947d0e94c18049af4d9f9907d4e0" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] name = "serde_json" -version = "1.0.110" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fbd975230bada99c8bb618e0c365c2eefa219158d5c6c29610fd09ff1833257" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ "indexmap", "itoa", @@ -3760,9 +3760,9 @@ dependencies = [ [[package]] name = "simd-json" -version = "0.13.4" +version = "0.13.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5a3720326b20bf5b95b72dbbd133caae7e0dcf71eae8f6e6656e71a7e5c9aaa" +checksum = "2faf8f101b9bc484337a6a6b0409cf76c139f2fb70a9e3aee6b6774be7bfbf76" dependencies = [ "ahash", "getrandom", @@ -3799,9 +3799,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.2" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "smartstring" @@ -3845,12 +3845,12 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" dependencies = [ "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -3921,7 +3921,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] @@ -3943,20 +3943,26 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.46" +version = "2.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89456b690ff72fddcecf231caedbe615c59480c93358a93dfae7fc29e3ebbf0e" +checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "sysinfo" -version = "0.30.3" +version = "0.30.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba2dbd2894d23b2d78dae768d85e323b557ac3ac71a5d917a31536d8f77ebada" +checksum = "1fb4f3438c8f6389c864e61221cbc97e9bca98b4daf39a5beb7bea660f528bb2" dependencies = [ "cfg-if", "core-foundation-sys", @@ -3995,50 +4001,50 @@ checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd" [[package]] name = "target-lexicon" -version = "0.12.12" +version = "0.12.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c39fd04924ca3a864207c66fc2cd7d22d7c016007f9ce846cbb9326331930a" +checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" [[package]] name = "tempfile" -version = "3.9.0" +version = "3.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01ce4141aa927a6d1bd34a041795abd0db1cccba5d5f24b009f694bdf3a1f3fa" +checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" dependencies = [ "cfg-if", "fastrand", - "redox_syscall", "rustix", "windows-sys 0.52.0", ] [[package]] name = "thiserror" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] name = "time" -version = "0.3.31" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" dependencies = [ "deranged", + "num-conv", "powerfmt", "serde", "time-core", @@ -4053,10 +4059,11 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" dependencies = [ + "num-conv", "time-core", ] @@ -4096,9 +4103,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.35.1" +version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ "backtrace", "bytes", @@ -4120,7 +4127,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] @@ -4207,7 +4214,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] @@ -4242,7 +4249,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] @@ -4259,9 +4266,9 @@ checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" [[package]] name = "unicode-bidi" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" [[package]] name = "unicode-ident" @@ -4271,9 +4278,9 @@ checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" dependencies = [ "tinyvec", ] @@ -4289,9 +4296,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" [[package]] name = "unicode-width" @@ -4330,18 +4337,19 @@ checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" [[package]] name = "uuid" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" dependencies = [ + "getrandom", "serde", ] [[package]] name = "value-trait" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea87257cfcbedcb9444eda79c59fdfea71217e6305afee8ee33f500375c2ac97" +checksum = "dad8db98c1e677797df21ba03fca7d3bf9bec3ca38db930954e4fe6e1ea27eb4" dependencies = [ "float-cmp", "halfbrown", @@ -4394,9 +4402,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.89" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" +checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -4404,24 +4412,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.89" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" +checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.39" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12" +checksum = "877b9c3f61ceea0e56331985743b13f3d25c406a7098d45180fb5f09bc19ed97" dependencies = [ "cfg-if", "js-sys", @@ -4431,9 +4439,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.89" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" +checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4441,28 +4449,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.89" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" +checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.89" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" +checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" [[package]] name = "wasm-streams" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7" +checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" dependencies = [ "futures-util", "js-sys", @@ -4473,9 +4481,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.66" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" +checksum = "96565907687f7aceb35bc5fc03770a8a0471d82e479f25832f54a0e3f4b28446" dependencies = [ "js-sys", "wasm-bindgen", @@ -4519,7 +4527,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ "windows-core", - "windows-targets 0.52.0", + "windows-targets 0.52.3", ] [[package]] @@ -4528,7 +4536,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.3", ] [[package]] @@ -4546,7 +4554,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.3", ] [[package]] @@ -4566,17 +4574,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "d380ba1dc7187569a8a9e91ed34b8ccfc33123bbacb8c0aed2d1ad7f3ef2dc5f" dependencies = [ - "windows_aarch64_gnullvm 0.52.0", - "windows_aarch64_msvc 0.52.0", - "windows_i686_gnu 0.52.0", - "windows_i686_msvc 0.52.0", - "windows_x86_64_gnu 0.52.0", - "windows_x86_64_gnullvm 0.52.0", - "windows_x86_64_msvc 0.52.0", + "windows_aarch64_gnullvm 0.52.3", + "windows_aarch64_msvc 0.52.3", + "windows_i686_gnu 0.52.3", + "windows_i686_msvc 0.52.3", + "windows_x86_64_gnu 0.52.3", + "windows_x86_64_gnullvm 0.52.3", + "windows_x86_64_msvc 0.52.3", ] [[package]] @@ -4587,9 +4595,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "68e5dcfb9413f53afd9c8f86e56a7b4d86d9a2fa26090ea2dc9e40fba56c6ec6" [[package]] name = "windows_aarch64_msvc" @@ -4599,9 +4607,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "8dab469ebbc45798319e69eebf92308e541ce46760b49b18c6b3fe5e8965b30f" [[package]] name = "windows_i686_gnu" @@ -4611,9 +4619,9 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "2a4e9b6a7cac734a8b4138a4e1044eac3404d8326b6c0f939276560687a033fb" [[package]] name = "windows_i686_msvc" @@ -4623,9 +4631,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "28b0ec9c422ca95ff34a78755cfa6ad4a51371da2a5ace67500cf7ca5f232c58" [[package]] name = "windows_x86_64_gnu" @@ -4635,9 +4643,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = "704131571ba93e89d7cd43482277d6632589b18ecf4468f591fbae0a8b101614" [[package]] name = "windows_x86_64_gnullvm" @@ -4647,9 +4655,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "42079295511643151e98d61c38c0acc444e52dd42ab456f7ccfd5152e8ecf21c" [[package]] name = "windows_x86_64_msvc" @@ -4659,15 +4667,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "0770833d60a970638e989b3fa9fd2bb1aaadcf88963d1659fd7d9990196ed2d6" [[package]] name = "winnow" -version = "0.5.31" +version = "0.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a4882e6b134d6c28953a387571f1acdd3496830d5e36c5e3a1075580ea641c" +checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" dependencies = [ "memchr", ] @@ -4690,9 +4698,9 @@ checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] name = "xxhash-rust" -version = "0.8.8" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53be06678ed9e83edb1745eb72efc0bbcd7b5c3c35711a860906aed827a13d61" +checksum = "927da81e25be1e1a2901d59b81b37dd2efd1fc9c9345a55007f09bf5a2d3ee03" [[package]] name = "zerocopy" @@ -4711,7 +4719,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.46", + "syn 2.0.50", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 457a52e39b52..80c173e21a6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,6 +71,7 @@ simdutf8 = "0.1.4" smartstring = "1" sqlparser = "0.39" streaming-iterator = "0.1.9" +strength_reduce = "0.2" strum_macros = "0.25" thiserror = "1" tokio = "1.26" @@ -80,6 +81,7 @@ url = "2.4" version_check = "0.9.4" xxhash-rust = { version = "0.8.6", features = ["xxh3"] } zstd = "0.13" +uuid = { version = "1.7.0", features = ["v4"] } polars = { version = "0.37.0", path = "crates/polars", default-features = false } polars-compute = { version = "0.37.0", path = "crates/polars-compute", default-features = false } @@ -110,6 +112,7 @@ default-features = false features = [ "compute_aggregate", "compute_arithmetics", + "compute_bitwise", "compute_boolean", "compute_boolean_kleene", "compute_cast", diff --git a/Makefile b/Makefile index 2b5079ced6d1..6e0ad8a149f2 100644 --- a/Makefile +++ b/Makefile @@ -19,90 +19,95 @@ FILTER_PIP_WARNINGS=| grep -v "don't match your environment"; test $${PIPESTATUS .PHONY: requirements requirements: .venv ## Install/refresh Python project requirements - $(VENV_BIN)/python -m pip install --upgrade pip - $(VENV_BIN)/pip install --upgrade -r py-polars/requirements-dev.txt - $(VENV_BIN)/pip install --upgrade -r py-polars/requirements-lint.txt - $(VENV_BIN)/pip install --upgrade -r py-polars/docs/requirements-docs.txt - $(VENV_BIN)/pip install --upgrade -r docs/requirements.txt + @unset CONDA_PREFIX \ + && $(VENV_BIN)/python -m pip install --upgrade uv \ + && $(VENV_BIN)/uv pip install --upgrade -r py-polars/requirements-dev.txt \ + && $(VENV_BIN)/uv pip install --upgrade -r py-polars/requirements-lint.txt \ + && $(VENV_BIN)/uv pip install --upgrade -r py-polars/docs/requirements-docs.txt \ + && $(VENV_BIN)/uv pip install --upgrade -r docs/requirements.txt .PHONY: build build: .venv ## Compile and install Python Polars for development - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ + @unset CONDA_PREFIX \ && maturin develop -m py-polars/Cargo.toml \ $(FILTER_PIP_WARNINGS) .PHONY: build-debug-opt build-debug-opt: .venv ## Compile and install Python Polars with minimal optimizations turned on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ - && maturin develop -m py-polars/Cargo.toml --profile opt-dev \ + @unset CONDA_PREFIX \ + && maturin develop -m py-polars/Cargo.toml --profile opt-dev \ $(FILTER_PIP_WARNINGS) .PHONY: build-debug-opt-subset build-debug-opt-subset: .venv ## Compile and install Python Polars with minimal optimizations turned on and no default features - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ + @unset CONDA_PREFIX \ && maturin develop -m py-polars/Cargo.toml --no-default-features --profile opt-dev \ $(FILTER_PIP_WARNINGS) .PHONY: build-opt build-opt: .venv ## Compile and install Python Polars with nearly full optimization on and debug assertions turned off, but with debug symbols on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ + @unset CONDA_PREFIX \ && maturin develop -m py-polars/Cargo.toml --profile debug-release \ $(FILTER_PIP_WARNINGS) .PHONY: build-release build-release: .venv ## Compile and install a faster Python Polars binary with full optimizations - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ + @unset CONDA_PREFIX \ && maturin develop -m py-polars/Cargo.toml --release \ $(FILTER_PIP_WARNINGS) .PHONY: build-native build-native: .venv ## Same as build, except with native CPU optimizations turned on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ + @unset CONDA_PREFIX \ && maturin develop -m py-polars/Cargo.toml -- -C target-cpu=native \ $(FILTER_PIP_WARNINGS) .PHONY: build-debug-opt-native build-debug-opt-native: .venv ## Same as build-debug-opt, except with native CPU optimizations turned on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ + @unset CONDA_PREFIX \ && maturin develop -m py-polars/Cargo.toml --profile opt-dev -- -C target-cpu=native \ $(FILTER_PIP_WARNINGS) .PHONY: build-opt-native build-opt-native: .venv ## Same as build-opt, except with native CPU optimizations turned on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ + @unset CONDA_PREFIX \ && maturin develop -m py-polars/Cargo.toml --profile debug-release -- -C target-cpu=native \ $(FILTER_PIP_WARNINGS) .PHONY: build-release-native build-release-native: .venv ## Same as build-release, except with native CPU optimizations turned on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ + @unset CONDA_PREFIX \ && maturin develop -m py-polars/Cargo.toml --release -- -C target-cpu=native \ $(FILTER_PIP_WARNINGS) + +.PHONY: check +check: ## Run cargo check with all features + cargo clippy --workspace --all-targets --all-features + .PHONY: clippy clippy: ## Run clippy with all features - cargo clippy --workspace --all-targets --all-features --locked -- -D warnings + cargo clippy --workspace --all-targets --all-features --locked -- -D warnings -D clippy::dbg_macro .PHONY: clippy-default clippy-default: ## Run clippy with default features - cargo clippy --all-targets --locked -- -D warnings + cargo clippy --all-targets --locked -- -D warnings -D clippy::dbg_macro .PHONY: fmt fmt: ## Run autoformatting and linting - $(VENV_BIN)/ruff check . - $(VENV_BIN)/ruff format . + $(VENV_BIN)/ruff check + $(VENV_BIN)/ruff format cargo fmt --all dprint fmt - $(VENV_BIN)/typos . + $(VENV_BIN)/typos .PHONY: pre-commit pre-commit: fmt clippy clippy-default ## Run all code quality checks .PHONY: clean clean: ## Clean up caches and build artifacts + @rm -rf .ruff_cache/ @rm -rf .venv/ - @rm -rf target/ - @rm -f Cargo.lock @cargo clean @$(MAKE) -s -C py-polars/ $@ diff --git a/README.md b/README.md index 1c677f1e3f4e..2f7fbf20c38e 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@

- + Polars logo

- + crates.io Latest Release PyPi Latest Release @@ -40,12 +40,12 @@ - R | - User Guide + User guide | Discord

-## Polars: Blazingly fast DataFrames in Rust, Python, Node.js, R and SQL +## Polars: Blazingly fast DataFrames in Rust, Python, Node.js, R, and SQL Polars is a DataFrame interface on top of an OLAP Query Engine implemented in Rust using [Apache Arrow Columnar Format](https://arrow.apache.org/docs/format/Columnar.html) as the memory model. @@ -55,10 +55,10 @@ Polars is a DataFrame interface on top of an OLAP Query Engine implemented in Ru - SIMD - Query optimization - Powerful expression API -- Hybrid Streaming (larger than RAM datasets) +- Hybrid Streaming (larger-than-RAM datasets) - Rust | Python | NodeJS | R | ... -To learn more, read the [User Guide](https://docs.pola.rs/). +To learn more, read the [user guide](https://docs.pola.rs/). ## Python @@ -103,9 +103,9 @@ shape: (5, 8) ```python >>> df = pl.scan_ipc("file.arrow") ->>> # create a sql context, registering the frame as a table +>>> # create a SQL context, registering the frame as a table >>> sql = pl.SQLContext(my_table=df) ->>> # create a sql query to execute +>>> # create a SQL query to execute >>> query = """ ... SELECT sum(v1) as sum_v1, min(v2) as min_v2 FROM my_table ... WHERE id1 = 'id016' @@ -136,7 +136,7 @@ shape: (5, 8) SQL commands can also be run directly from your terminal using the Polars CLI: ```bash -# run an inline sql query +# run an inline SQL query > polars -c "SELECT sum(v1) as sum_v1, min(v2) as min_v2 FROM read_ipc('file.arrow') WHERE id1 = 'id016' LIMIT 10" # run interactively @@ -156,7 +156,7 @@ Refer to the [Polars CLI repository](https://github.com/pola-rs/polars-cli) for Polars is very fast. In fact, it is one of the best performing solutions available. See the results in [DuckDB's db-benchmark](https://duckdblabs.github.io/db-benchmark/). -In the [TPCH benchmarks](https://www.pola.rs/benchmarks.html) Polars is orders of magnitudes faster than pandas, dask, modin and vaex +In the [TPC-H benchmarks](https://www.pola.rs/benchmarks.html) Polars is orders of magnitude faster than pandas, dask, modin and vaex on full queries (including IO). ### Lightweight @@ -167,18 +167,18 @@ Polars is also very lightweight. It comes with zero required dependencies, and t - numpy: 104ms - pandas: 520ms -### Handles larger than RAM data +### Handles larger-than-RAM data -If you have data that does not fit into memory, polars lazy is able to process your query (or parts of your query) in a -streaming fashion, this drastically reduces memory requirements so you might be able to process your 250GB dataset on your -laptop. Collect with `collect(streaming=True)` to run the query streaming. (This might be a little slower, but -it is still very fast!) +If you have data that does not fit into memory, Polars' query engine is able to process your query (or parts of your query) in a streaming fashion. +This drastically reduces memory requirements, so you might be able to process your 250GB dataset on your laptop. +Collect with `collect(streaming=True)` to run the query streaming. +(This might be a little slower, but it is still very fast!) ## Setup ### Python -Install the latest polars version with: +Install the latest Polars version with: ```sh pip install polars @@ -210,10 +210,10 @@ pip install 'polars[numpy,pandas,pyarrow]' | openpyxl | Support for reading from Excel files with native types | | deltalake | Support for reading and writing Delta Lake Tables | | pyiceberg | Support for reading from Apache Iceberg tables | -| plot | Support for plot functions on Dataframes | -| timezone | Timezone support, only needed if are on Python<3.9 or you are on Windows | +| plot | Support for plot functions on DataFrames | +| timezone | Timezone support, only needed if you are on Python<3.9 or Windows | -Releases happen quite often (weekly / every few days) at the moment, so updating polars regularly to get the latest bugfixes / features might not be a bad idea. +Releases happen quite often (weekly / every few days) at the moment, so updating Polars regularly to get the latest bugfixes / features might not be a bad idea. ### Rust @@ -224,7 +224,7 @@ point to the `main` branch of this repo. polars = { git = "https://github.com/pola-rs/polars", rev = "" } ``` -Required Rust version `>=1.71`. +Requires Rust version `>=1.71`. ## Contributing @@ -232,7 +232,7 @@ Want to contribute? Read our [contribution guideline](/CONTRIBUTING.md). ## Python: compile Polars from source -If you want a bleeding edge release or maximal performance you should compile **Polars** from source. +If you want a bleeding edge release or maximal performance you should compile Polars from source. This can be done by going through the following steps in sequence: @@ -253,16 +253,16 @@ can `pip install polars` and `import polars`. ## Use custom Rust function in Python? -Extending Polars with UDFs compiled in Rust is easy. We expose pyo3 extensions for `DataFrame` and `Series` +Extending Polars with UDFs compiled in Rust is easy. We expose PyO3 extensions for `DataFrame` and `Series` data structures. See more in https://github.com/pola-rs/pyo3-polars. ## Going big... -Do you expect more than `2^32` ~4,2 billion rows? Compile polars with the `bigidx` feature flag. +Do you expect more than 2^32 (~4.2 billion) rows? Compile Polars with the `bigidx` feature flag. -Or for Python users install `pip install polars-u64-idx`. +Or for Python users, install `pip install polars-u64-idx`. -Don't use this unless you hit the row boundary as the default Polars is faster and consumes less memory. +Don't use this unless you hit the row boundary, as the default build of Polars is faster and consumes less memory. ## Legacy @@ -273,4 +273,4 @@ features. ## Sponsors -[](https://www.jetbrains.com) +[JetBrains logo](https://www.jetbrains.com) diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000000..65df338d67b4 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,19 @@ +coverage: + status: + project: off + patch: off +ignore: + - crates/polars-arrow/src/io/flight/*.rs + - crates/polars-arrow/src/io/ipc/append/*.rs + - crates/polars-arrow/src/io/ipc/read/array/union.rs + - crates/polars-arrow/src/io/ipc/read/array/map.rs + - crates/polars-arrow/src/io/ipc/read/array/binary.rs + - crates/polars-arrow/src/io/ipc/read/array/fixed_size_binary.rs + - crates/polars-arrow/src/io/ipc/read/array/null.rs + - crates/polars-arrow/src/io/ipc/write/serialize/fixed_size_binary.rs + - crates/polars-arrow/src/io/ipc/write/serialize/union.rs + - crates/polars-arrow/src/io/ipc/write/serialize/map.rs + - crates/polars-arrow/src/array/union/*.rs + - crates/polars-arrow/src/array/map/*.rs + - crates/polars-arrow/src/array/fixed_size_binary/*.rs + diff --git a/crates/Makefile b/crates/Makefile index 8cb3ec2da6dd..6e4ded353458 100644 --- a/crates/Makefile +++ b/crates/Makefile @@ -14,11 +14,11 @@ check: ## Run cargo check with all features .PHONY: clippy clippy: ## Run clippy with all features - cargo clippy -p polars --all-features + cargo clippy -p polars --all-features -- -W clippy::dbg_macro .PHONY: clippy-default clippy-default: ## Run clippy with default features - cargo clippy -p polars + cargo clippy -p polars -- -W clippy::dbg_macro .PHONY: pre-commit pre-commit: fmt clippy clippy-default ## Run autoformatting and linting diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index dbd0910cc8ee..f19509b7b96a 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -63,7 +63,7 @@ async-stream = { version = "0.3.2", optional = true } avro-schema = { workspace = true, optional = true } # for division/remainder optimization at runtime -strength_reduce = { version = "0.2", optional = true } +strength_reduce = { workspace = true, optional = true } # For instruction multiversioning multiversion = { workspace = true, optional = true } @@ -78,7 +78,6 @@ arrow-data = { workspace = true, optional = true } arrow-schema = { workspace = true, optional = true } [dev-dependencies] -apache-avro = { version = "0.16", features = ["snappy"] } criterion = "0.5" crossbeam-channel = { workspace = true } doc-comment = "0.3" diff --git a/crates/polars-arrow/src/array/binary/data.rs b/crates/polars-arrow/src/array/binary/data.rs index 56835dec0c42..a45ebcca0621 100644 --- a/crates/polars-arrow/src/array/binary/data.rs +++ b/crates/polars-arrow/src/array/binary/data.rs @@ -15,7 +15,7 @@ impl Arrow2Arrow for BinaryArray { ]) .nulls(self.validity.as_ref().map(|b| b.clone().into())); - // Safety: Array is valid + // SAFETY: Array is valid unsafe { builder.build_unchecked() } } @@ -29,7 +29,7 @@ impl Arrow2Arrow for BinaryArray { let buffers = data.buffers(); - // Safety: ArrayData is valid + // SAFETY: ArrayData is valid let mut offsets = unsafe { OffsetsBuffer::new_unchecked(buffers[0].clone().into()) }; offsets.slice(data.offset(), data.len() + 1); diff --git a/crates/polars-arrow/src/array/binary/from.rs b/crates/polars-arrow/src/array/binary/from.rs index 73df03531594..9ffac9827bb8 100644 --- a/crates/polars-arrow/src/array/binary/from.rs +++ b/crates/polars-arrow/src/array/binary/from.rs @@ -1,5 +1,3 @@ -use std::iter::FromIterator; - use super::{BinaryArray, MutableBinaryArray}; use crate::offset::Offset; diff --git a/crates/polars-arrow/src/array/binary/mod.rs b/crates/polars-arrow/src/array/binary/mod.rs index 7031bc78245e..a0ba77030a83 100644 --- a/crates/polars-arrow/src/array/binary/mod.rs +++ b/crates/polars-arrow/src/array/binary/mod.rs @@ -161,6 +161,7 @@ impl BinaryArray { } /// Returns the element at index `i` + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] @@ -225,6 +226,7 @@ impl BinaryArray { /// Slices this [`BinaryArray`]. /// # Implementation /// This function is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -258,7 +260,7 @@ impl BinaryArray { use Either::*; if let Some(bitmap) = self.validity { match bitmap.into_mut() { - // Safety: invariants are preserved + // SAFETY: invariants are preserved Left(bitmap) => Left(BinaryArray::new( self.data_type, self.offsets, @@ -374,6 +376,7 @@ impl BinaryArray { } /// Creates a [`BinaryArray`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -398,6 +401,7 @@ impl BinaryArray { } /// Creates a [`BinaryArray`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/binary/mutable.rs b/crates/polars-arrow/src/array/binary/mutable.rs index 8b4aca205968..53a8ed32bb6f 100644 --- a/crates/polars-arrow/src/array/binary/mutable.rs +++ b/crates/polars-arrow/src/array/binary/mutable.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -236,6 +235,7 @@ impl> FromIterator> for MutableBinaryArray MutableBinaryArray { /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -262,6 +262,7 @@ impl MutableBinaryArray { } /// Creates a new [`BinaryArray`] from a [`TrustedLen`] of `&[u8]`. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -283,6 +284,7 @@ impl MutableBinaryArray { } /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -325,7 +327,7 @@ impl MutableBinaryArray { P: AsRef<[u8]>, I: TrustedLen, { - // Safety: The iterator is `TrustedLen` + // SAFETY: The iterator is `TrustedLen` unsafe { self.extend_trusted_len_values_unchecked(iterator) } } @@ -349,6 +351,7 @@ impl MutableBinaryArray { /// Extends the [`MutableBinaryArray`] from an `iterator` of values of trusted length. /// This differs from `extend_trusted_len_unchecked` which accepts iterator of optional /// values. + /// /// # Safety /// The `iterator` must be [`TrustedLen`] #[inline] @@ -373,11 +376,12 @@ impl MutableBinaryArray { P: AsRef<[u8]>, I: TrustedLen>, { - // Safety: The iterator is `TrustedLen` + // SAFETY: The iterator is `TrustedLen` unsafe { self.extend_trusted_len_unchecked(iterator) } } /// Extends the [`MutableBinaryArray`] from an iterator of [`TrustedLen`] + /// /// # Safety /// The `iterator` must be [`TrustedLen`] #[inline] diff --git a/crates/polars-arrow/src/array/binary/mutable_values.rs b/crates/polars-arrow/src/array/binary/mutable_values.rs index d6c8c969f058..613cbb0aba9e 100644 --- a/crates/polars-arrow/src/array/binary/mutable_values.rs +++ b/crates/polars-arrow/src/array/binary/mutable_values.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -163,6 +162,7 @@ impl MutableBinaryValuesArray { } /// Returns the value of the element at index `i`. + /// /// # Safety /// This function is safe iff `i < self.len`. #[inline] @@ -266,6 +266,7 @@ impl MutableBinaryValuesArray { } /// Extends [`MutableBinaryValuesArray`] from an iterator of trusted len. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -289,6 +290,7 @@ impl MutableBinaryValuesArray { } /// Returns a new [`MutableBinaryValuesArray`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/binview/mod.rs b/crates/polars-arrow/src/array/binview/mod.rs index eb16c473e714..86413d91e122 100644 --- a/crates/polars-arrow/src/array/binview/mod.rs +++ b/crates/polars-arrow/src/array/binview/mod.rs @@ -178,6 +178,7 @@ impl BinaryViewArrayGeneric { } /// Create a new BinaryViewArray but initialize a statistics compute. + /// /// # Safety /// The caller must ensure the invariants pub unsafe fn new_unchecked_unknown_md( @@ -267,6 +268,7 @@ impl BinaryViewArrayGeneric { } /// Returns the element at index `i` + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] @@ -373,6 +375,10 @@ impl BinaryViewArrayGeneric { mutable.freeze().with_validity(self.validity) } + pub fn is_sliced(&self) -> bool { + self.views.as_ptr() != self.views.storage_ptr() + } + pub fn maybe_gc(self) -> Self { const GC_MINIMUM_SAVINGS: usize = 16 * 1024; // At least 16 KiB. @@ -418,7 +424,7 @@ impl BinaryViewArray { /// Validate the underlying bytes on UTF-8. pub fn validate_utf8(&self) -> PolarsResult<()> { // SAFETY: views are correct - unsafe { validate_utf8_only(&self.views, &self.buffers) } + unsafe { validate_utf8_only(&self.views, &self.buffers, &self.buffers) } } /// Convert [`BinaryViewArray`] to [`Utf8ViewArray`]. diff --git a/crates/polars-arrow/src/array/binview/mutable.rs b/crates/polars-arrow/src/array/binview/mutable.rs index 4d62ff592c87..26a3a2faa930 100644 --- a/crates/polars-arrow/src/array/binview/mutable.rs +++ b/crates/polars-arrow/src/array/binview/mutable.rs @@ -99,6 +99,11 @@ impl MutableBinaryViewArray { &self.views } + #[inline] + pub fn completed_buffers(&self) -> &[Buffer] { + &self.completed_buffers + } + pub fn validity(&mut self) -> Option<&mut MutableBitmap> { self.validity.as_mut() } @@ -308,10 +313,13 @@ impl MutableBinaryViewArray { Self::from_iterator(slice.as_ref().iter().map(|opt_v| opt_v.as_ref())) } - fn finish_in_progress(&mut self) { + fn finish_in_progress(&mut self) -> bool { if !self.in_progress_buffer.is_empty() { self.completed_buffers .push(std::mem::take(&mut self.in_progress_buffer).into()); + true + } else { + false } } @@ -321,6 +329,7 @@ impl MutableBinaryViewArray { } /// Returns the element at index `i` + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] @@ -363,10 +372,22 @@ impl MutableBinaryViewArray { } impl MutableBinaryViewArray<[u8]> { - pub fn validate_utf8(&mut self) -> PolarsResult<()> { - self.finish_in_progress(); + pub fn validate_utf8(&mut self, buffer_offset: usize, views_offset: usize) -> PolarsResult<()> { + // Finish the in progress as it might be required for validation. + let pushed = self.finish_in_progress(); // views are correct - unsafe { validate_utf8_only(&self.views, &self.completed_buffers) } + unsafe { + validate_utf8_only( + &self.views[views_offset..], + &self.completed_buffers[buffer_offset..], + &self.completed_buffers, + )? + } + // Restore in-progress buffer as we don't want to get too small buffers + if let (true, Some(last)) = (pushed, self.completed_buffers.pop()) { + self.in_progress_buffer = last.into_mut().right().unwrap(); + } + Ok(()) } } diff --git a/crates/polars-arrow/src/array/binview/view.rs b/crates/polars-arrow/src/array/binview/view.rs index 34e7d799d3ea..921489c380a3 100644 --- a/crates/polars-arrow/src/array/binview/view.rs +++ b/crates/polars-arrow/src/array/binview/view.rs @@ -193,25 +193,51 @@ pub(super) fn validate_utf8_view(views: &[View], buffers: &[Buffer]) -> Pola /// The views and buffers must uphold the invariants of BinaryView otherwise we will go OOB. pub(super) unsafe fn validate_utf8_only( views: &[View], - buffers: &[Buffer], + buffers_to_check: &[Buffer], + all_buffers: &[Buffer], ) -> PolarsResult<()> { - for view in views { - let len = view.length; - if len <= 12 { + // If we have no buffers, we don't have to branch. + if all_buffers.is_empty() { + for view in views { + let len = view.length; validate_utf8( view.to_le_bytes() .get_unchecked_release(4..4 + len as usize), )?; - } else { - let buffer_idx = view.buffer_idx; - let offset = view.offset; - let data = buffers.get_unchecked_release(buffer_idx as usize); - - let start = offset as usize; - let end = start + len as usize; - let b = &data.as_slice().get_unchecked_release(start..end); - validate_utf8(b)?; - }; + } + return Ok(()); + } + + // Fast path if all buffers are ascii + if buffers_to_check.iter().all(|buf| buf.is_ascii()) { + for view in views { + let len = view.length; + if len <= 12 { + validate_utf8( + view.to_le_bytes() + .get_unchecked_release(4..4 + len as usize), + )?; + } + } + } else { + for view in views { + let len = view.length; + if len <= 12 { + validate_utf8( + view.to_le_bytes() + .get_unchecked_release(4..4 + len as usize), + )?; + } else { + let buffer_idx = view.buffer_idx; + let offset = view.offset; + let data = all_buffers.get_unchecked_release(buffer_idx as usize); + + let start = offset as usize; + let end = start + len as usize; + let b = &data.as_slice().get_unchecked_release(start..end); + validate_utf8(b)?; + }; + } } Ok(()) diff --git a/crates/polars-arrow/src/array/boolean/data.rs b/crates/polars-arrow/src/array/boolean/data.rs index e23038687d3e..f472348a0407 100644 --- a/crates/polars-arrow/src/array/boolean/data.rs +++ b/crates/polars-arrow/src/array/boolean/data.rs @@ -15,7 +15,7 @@ impl Arrow2Arrow for BooleanArray { .buffers(vec![buffer.into_inner().into_inner()]) .nulls(self.validity.as_ref().map(|b| b.clone().into())); - // Safety: Array is valid + // SAFETY: Array is valid unsafe { builder.build_unchecked() } } diff --git a/crates/polars-arrow/src/array/boolean/from.rs b/crates/polars-arrow/src/array/boolean/from.rs index 81a5395ccc06..07553d78b737 100644 --- a/crates/polars-arrow/src/array/boolean/from.rs +++ b/crates/polars-arrow/src/array/boolean/from.rs @@ -1,5 +1,3 @@ -use std::iter::FromIterator; - use super::{BooleanArray, MutableBooleanArray}; impl]>> From

for BooleanArray { diff --git a/crates/polars-arrow/src/array/boolean/mod.rs b/crates/polars-arrow/src/array/boolean/mod.rs index eb3401ee8d83..0737dcfece55 100644 --- a/crates/polars-arrow/src/array/boolean/mod.rs +++ b/crates/polars-arrow/src/array/boolean/mod.rs @@ -136,6 +136,7 @@ impl BooleanArray { } /// Returns the element at index `i` as bool + /// /// # Safety /// Caller must be sure that `i < self.len()` #[inline] @@ -173,6 +174,7 @@ impl BooleanArray { /// Slices this [`BooleanArray`]. /// # Implementation /// This operation is `O(1)` as it amounts to increase two ref counts. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. #[inline] @@ -279,6 +281,7 @@ impl BooleanArray { /// Creates a new [`BooleanArray`] from an [`TrustedLen`] of `bool`. /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len /// but this crate does not mark it as such. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -298,6 +301,7 @@ impl BooleanArray { /// Creates a [`BooleanArray`] from an iterator of trusted length. /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len /// but this crate does not mark it as such. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -321,6 +325,7 @@ impl BooleanArray { } /// Creates a [`BooleanArray`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/boolean/mutable.rs b/crates/polars-arrow/src/array/boolean/mutable.rs index 50b50574a7a2..fd3b4a1989d7 100644 --- a/crates/polars-arrow/src/array/boolean/mutable.rs +++ b/crates/polars-arrow/src/array/boolean/mutable.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -136,12 +135,13 @@ impl MutableBooleanArray { where I: TrustedLen, { - // Safety: `I` is `TrustedLen` + // SAFETY: `I` is `TrustedLen` unsafe { self.extend_trusted_len_values_unchecked(iterator) } } /// Extends the [`MutableBooleanArray`] from an iterator of values of trusted len. /// This differs from `extend_trusted_len_unchecked`, which accepts in iterator of optional values. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -167,11 +167,12 @@ impl MutableBooleanArray { P: std::borrow::Borrow, I: TrustedLen>, { - // Safety: `I` is `TrustedLen` + // SAFETY: `I` is `TrustedLen` unsafe { self.extend_trusted_len_unchecked(iterator) } } /// Extends the [`MutableBooleanArray`] from an iterator of trusted len. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -255,6 +256,7 @@ impl MutableBooleanArray { /// Creates a new [`MutableBooleanArray`] from an [`TrustedLen`] of `bool`. /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len /// but this crate does not mark it as such. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -276,6 +278,7 @@ impl MutableBooleanArray { /// Creates a [`BooleanArray`] from an iterator of trusted length. /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len /// but this crate does not mark it as such. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -297,11 +300,12 @@ impl MutableBooleanArray { P: std::borrow::Borrow, I: TrustedLen>, { - // Safety: `I` is `TrustedLen` + // SAFETY: `I` is `TrustedLen` unsafe { Self::from_trusted_len_iter_unchecked(iterator) } } /// Creates a [`BooleanArray`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -331,7 +335,7 @@ impl MutableBooleanArray { P: std::borrow::Borrow, I: TrustedLen, E>>, { - // Safety: `I` is `TrustedLen` + // SAFETY: `I` is `TrustedLen` unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } } @@ -555,7 +559,7 @@ impl TryExtendFromSelf for MutableBooleanArray { extend_validity(self.len(), &mut self.validity, &other.validity); let slice = other.values.as_slice(); - // safety: invariant offset + length <= slice.len() + // SAFETY: invariant offset + length <= slice.len() unsafe { self.values .extend_from_slice_unchecked(slice, 0, other.values.len()); diff --git a/crates/polars-arrow/src/array/dictionary/data.rs b/crates/polars-arrow/src/array/dictionary/data.rs index 29eb1da057b3..e7159e4bfff2 100644 --- a/crates/polars-arrow/src/array/dictionary/data.rs +++ b/crates/polars-arrow/src/array/dictionary/data.rs @@ -13,7 +13,7 @@ impl Arrow2Arrow for DictionaryArray { .data_type(self.data_type.clone().into()) .child_data(vec![to_data(self.values.as_ref())]); - // Safety: Dictionary is valid + // SAFETY: Dictionary is valid unsafe { builder.build_unchecked() } } @@ -35,7 +35,7 @@ impl Arrow2Arrow for DictionaryArray { .len(data.len()) .nulls(data.nulls().cloned()); - // Safety: Dictionary is valid + // SAFETY: Dictionary is valid let key_data = unsafe { key_builder.build_unchecked() }; let keys = PrimitiveArray::from_data(&key_data); let values = from_data(&data.child_data()[0]); diff --git a/crates/polars-arrow/src/array/dictionary/mod.rs b/crates/polars-arrow/src/array/dictionary/mod.rs index df59504c0b98..6947c9e071c7 100644 --- a/crates/polars-arrow/src/array/dictionary/mod.rs +++ b/crates/polars-arrow/src/array/dictionary/mod.rs @@ -37,6 +37,7 @@ pub unsafe trait DictionaryKey: NativeType + TryInto + TryFrom + H const KEY_TYPE: IntegerType; /// Represents this key as a `usize`. + /// /// # Safety /// The caller _must_ have checked that the value can be casted to `usize`. #[inline] @@ -146,7 +147,7 @@ impl DictionaryArray { if keys.null_count() != keys.len() { if K::always_fits_usize() { - // safety: we just checked that conversion to `usize` always + // SAFETY: we just checked that conversion to `usize` always // succeeds unsafe { check_indexes_unchecked(keys.values(), values.len()) }?; } else { @@ -178,6 +179,7 @@ impl DictionaryArray { /// * the `data_type`'s logical type is not a `DictionaryArray` /// * the `data_type`'s keys is not compatible with `keys` /// * the `data_type`'s values's data_type is not equal with `values.data_type()` + /// /// # Safety /// The caller must ensure that every keys's values is represented in `usize` and is `< values.len()` pub unsafe fn try_new_unchecked( @@ -292,6 +294,7 @@ impl DictionaryArray { } /// Slices this [`DictionaryArray`]. + /// /// # Safety /// Safe iff `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -340,14 +343,14 @@ impl DictionaryArray { /// Returns an iterator of the keys' values of the [`DictionaryArray`] as `usize` #[inline] pub fn keys_values_iter(&self) -> impl TrustedLen + Clone + '_ { - // safety - invariant of the struct + // SAFETY: invariant of the struct self.keys.values_iter().map(|x| unsafe { x.as_usize() }) } /// Returns an iterator of the keys' of the [`DictionaryArray`] as `usize` #[inline] pub fn keys_iter(&self) -> impl TrustedLen> + Clone + '_ { - // safety - invariant of the struct + // SAFETY: invariant of the struct self.keys.iter().map(|x| x.map(|x| unsafe { x.as_usize() })) } @@ -356,7 +359,7 @@ impl DictionaryArray { /// This function panics iff `index >= self.len()` #[inline] pub fn key_value(&self, index: usize) -> usize { - // safety - invariant of the struct + // SAFETY: invariant of the struct unsafe { self.keys.values()[index].as_usize() } } @@ -374,7 +377,7 @@ impl DictionaryArray { /// This function panics iff `index >= self.len()` #[inline] pub fn value(&self, index: usize) -> Box { - // safety - invariant of this struct + // SAFETY: invariant of this struct let index = unsafe { self.keys.value(index).as_usize() }; new_scalar(self.values.as_ref(), index) } diff --git a/crates/polars-arrow/src/array/dictionary/mutable.rs b/crates/polars-arrow/src/array/dictionary/mutable.rs index 9c6e783dd28e..d55ba6484443 100644 --- a/crates/polars-arrow/src/array/dictionary/mutable.rs +++ b/crates/polars-arrow/src/array/dictionary/mutable.rs @@ -21,7 +21,7 @@ pub struct MutableDictionaryArray { impl From> for DictionaryArray { fn from(other: MutableDictionaryArray) -> Self { - // Safety - the invariant of this struct ensures that this is up-held + // SAFETY: the invariant of this struct ensures that this is up-held unsafe { DictionaryArray::::try_new_unchecked( other.data_type, diff --git a/crates/polars-arrow/src/array/dictionary/value_map.rs b/crates/polars-arrow/src/array/dictionary/value_map.rs index 2be9a7ca1047..5b6bdb9528ba 100644 --- a/crates/polars-arrow/src/array/dictionary/value_map.rs +++ b/crates/polars-arrow/src/array/dictionary/value_map.rs @@ -81,15 +81,16 @@ impl ValueMap { ); for index in 0..values.len() { let key = K::try_from(index).map_err(|_| polars_err!(ComputeError: "overflow"))?; - // safety: we only iterate within bounds + // SAFETY: we only iterate within bounds let value = unsafe { values.value_unchecked_at(index) }; let hash = ahash_hash(value.borrow()); - match map.raw_entry_mut().from_hash(hash, |item| { - // safety: invariant of the struct, it's always in bounds since we maintain it + let entry = map.raw_entry_mut().from_hash(hash, |item| { + // SAFETY: invariant of the struct, it's always in bounds since we maintain it let stored_value = unsafe { values.value_unchecked_at(item.key.as_usize()) }; stored_value.borrow() == value.borrow() - }) { + }); + match entry { RawEntryMut::Occupied(_) => { polars_bail!(InvalidOperation: "duplicate value in dictionary values array") }, @@ -133,26 +134,25 @@ impl ValueMap { M::Type: Eq + Hash, { let hash = ahash_hash(value.as_indexed()); - Ok( - match self.map.raw_entry_mut().from_hash(hash, |item| { - // safety: we've already checked (the inverse) when we pushed it, so it should be ok? - let index = unsafe { item.key.as_usize() }; - // safety: invariant of the struct, it's always in bounds since we maintain it - let stored_value = unsafe { self.values.value_unchecked_at(index) }; - stored_value.borrow() == value.as_indexed() - }) { - RawEntryMut::Occupied(entry) => entry.key().key, - RawEntryMut::Vacant(entry) => { - let index = self.values.len(); - let key = - K::try_from(index).map_err(|_| polars_err!(ComputeError: "overflow"))?; - entry.insert_hashed_nocheck(hash, Hashed { hash, key }, ()); // NB: don't use .insert() here! - push(&mut self.values, value)?; - debug_assert_eq!(self.values.len(), index + 1); - key - }, + let entry = self.map.raw_entry_mut().from_hash(hash, |item| { + // SAFETY: we've already checked (the inverse) when we pushed it, so it should be ok? + let index = unsafe { item.key.as_usize() }; + // SAFETY: invariant of the struct, it's always in bounds since we maintain it + let stored_value = unsafe { self.values.value_unchecked_at(index) }; + stored_value.borrow() == value.as_indexed() + }); + let out = match entry { + RawEntryMut::Occupied(entry) => entry.key().key, + RawEntryMut::Vacant(entry) => { + let index = self.values.len(); + let key = K::try_from(index).map_err(|_| polars_err!(ComputeError: "overflow"))?; + entry.insert_hashed_nocheck(hash, Hashed { hash, key }, ()); // NB: don't use .insert() here! + push(&mut self.values, value)?; + debug_assert_eq!(self.values.len(), index + 1); + key }, - ) + }; + Ok(out) } pub fn shrink_to_fit(&mut self) { diff --git a/crates/polars-arrow/src/array/ffi.rs b/crates/polars-arrow/src/array/ffi.rs index e1dd62488b70..9806eac25e97 100644 --- a/crates/polars-arrow/src/array/ffi.rs +++ b/crates/polars-arrow/src/array/ffi.rs @@ -26,6 +26,7 @@ pub(crate) unsafe trait ToFfi { /// [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). pub(crate) trait FromFfi: Sized { /// Convert itself from FFI. + /// /// # Safety /// This function is intrinsically `unsafe` as it requires the FFI to be made according /// to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) diff --git a/crates/polars-arrow/src/array/fixed_size_binary/data.rs b/crates/polars-arrow/src/array/fixed_size_binary/data.rs index a0662d7d6555..f99822eb0fbb 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/data.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/data.rs @@ -13,7 +13,7 @@ impl Arrow2Arrow for FixedSizeBinaryArray { .buffers(vec![self.values.clone().into()]) .nulls(self.validity.as_ref().map(|b| b.clone().into())); - // Safety: Array is valid + // SAFETY: Array is valid unsafe { builder.build_unchecked() } } diff --git a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs index 70e421e746ad..e439aac214aa 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs @@ -105,6 +105,7 @@ impl FixedSizeBinaryArray { /// Slices this [`FixedSizeBinaryArray`]. /// # Implementation /// This operation is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -151,6 +152,7 @@ impl FixedSizeBinaryArray { } /// Returns the element at index `i` as &str + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] @@ -264,20 +266,3 @@ impl FixedSizeBinaryArray { MutableFixedSizeBinaryArray::from(slice).into() } } - -pub trait FixedSizeBinaryValues { - fn values(&self) -> &[u8]; - fn size(&self) -> usize; -} - -impl FixedSizeBinaryValues for FixedSizeBinaryArray { - #[inline] - fn values(&self) -> &[u8] { - &self.values - } - - #[inline] - fn size(&self) -> usize { - self.size - } -} diff --git a/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs b/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs index aba3904bef4e..8f81ce86f6d8 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; -use super::{FixedSizeBinaryArray, FixedSizeBinaryValues}; +use super::FixedSizeBinaryArray; use crate::array::physical_binary::extend_validity; use crate::array::{Array, MutableArray, TryExtendFromSelf}; use crate::bitmap::MutableBitmap; @@ -200,6 +200,7 @@ impl MutableFixedSizeBinaryArray { } /// Returns the element at index `i` as `&[u8]` + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] @@ -289,18 +290,6 @@ impl MutableArray for MutableFixedSizeBinaryArray { } } -impl FixedSizeBinaryValues for MutableFixedSizeBinaryArray { - #[inline] - fn values(&self) -> &[u8] { - &self.values - } - - #[inline] - fn size(&self) -> usize { - self.size - } -} - impl PartialEq for MutableFixedSizeBinaryArray { fn eq(&self, other: &Self) -> bool { self.iter().eq(other.iter()) diff --git a/crates/polars-arrow/src/array/fixed_size_list/data.rs b/crates/polars-arrow/src/array/fixed_size_list/data.rs index 5a0b902e9b51..f98fa452c6ea 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/data.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/data.rs @@ -12,7 +12,7 @@ impl Arrow2Arrow for FixedSizeListArray { .nulls(self.validity.as_ref().map(|b| b.clone().into())) .child_data(vec![to_data(self.values.as_ref())]); - // Safety: Array is valid + // SAFETY: Array is valid unsafe { builder.build_unchecked() } } diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs index 612e134a5a5d..2cefb2e8ddaf 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -111,6 +111,7 @@ impl FixedSizeListArray { /// Slices this [`FixedSizeListArray`]. /// # Implementation /// This operation is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -156,6 +157,7 @@ impl FixedSizeListArray { } /// Returns the `Vec` at position `i`. + /// /// # Safety /// Caller must ensure that `i < self.len()` #[inline] diff --git a/crates/polars-arrow/src/array/growable/binview.rs b/crates/polars-arrow/src/array/growable/binview.rs index 76cf75e8bfa7..d167a09a8fe8 100644 --- a/crates/polars-arrow/src/array/growable/binview.rs +++ b/crates/polars-arrow/src/array/growable/binview.rs @@ -1,5 +1,10 @@ +use std::hash::{Hash, Hasher}; use std::sync::Arc; +use polars_utils::aliases::PlIndexSet; +use polars_utils::slice::GetSaferUnchecked; +use polars_utils::unwrap::UnwrapUncheckedRelease; + use super::Growable; use crate::array::binview::{BinaryViewArrayGeneric, View, ViewType}; use crate::array::growable::utils::{extend_validity, prepare_validity}; @@ -8,14 +13,35 @@ use crate::bitmap::MutableBitmap; use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; +struct BufferKey<'a> { + inner: &'a Buffer, +} + +impl Hash for BufferKey<'_> { + fn hash(&self, state: &mut H) { + state.write_u64(self.inner.as_ptr() as u64) + } +} + +impl PartialEq for BufferKey<'_> { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.inner.as_ptr() == other.inner.as_ptr() + } +} + +impl Eq for BufferKey<'_> {} + /// Concrete [`Growable`] for the [`BinaryArray`]. pub struct GrowableBinaryViewArray<'a, T: ViewType + ?Sized> { arrays: Vec<&'a BinaryViewArrayGeneric>, data_type: ArrowDataType, validity: Option, views: Vec, - buffers: Vec>, - buffers_idx_offsets: Vec, + // We need to use a set/hashmap to deduplicate + // A growable can be called with many chunks from self. + // See: #14201 + buffers: PlIndexSet>, total_bytes_len: usize, total_buffer_len: usize, } @@ -37,21 +63,16 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { use_validity = true; }; - let mut cum_sum = 0; - let cum_offset = arrays - .iter() - .map(|binview| { - let out = cum_sum; - cum_sum += binview.data_buffers().len() as u32; - out - }) - .collect::>(); - let buffers = arrays .iter() - .flat_map(|array| array.data_buffers().as_ref()) - .cloned() - .collect::>(); + .flat_map(|array| { + array + .data_buffers() + .as_ref() + .iter() + .map(|buf| BufferKey { inner: buf }) + }) + .collect::>(); let total_buffer_len = arrays .iter() .map(|arr| arr.data_buffers().len()) @@ -63,7 +84,6 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { validity: prepare_validity(use_validity, capacity), views: Vec::with_capacity(capacity), buffers, - buffers_idx_offsets: cum_offset, total_bytes_len: 0, total_buffer_len, } @@ -77,7 +97,12 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { BinaryViewArrayGeneric::::new_unchecked( self.data_type.clone(), views.into(), - Arc::from(buffers), + Arc::from( + buffers + .into_iter() + .map(|buf| buf.inner.clone()) + .collect::>(), + ), validity.map(|v| v.into()), self.total_bytes_len, self.total_buffer_len, @@ -90,6 +115,7 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { /// doesn't check bounds pub unsafe fn extend_unchecked(&mut self, index: usize, start: usize, len: usize) { let array = *self.arrays.get_unchecked(index); + let local_buffers = array.data_buffers(); extend_validity(&mut self.validity, array, start, len); @@ -102,8 +128,11 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { self.total_bytes_len += len; if len > 12 { - let buffer_idx = *self.buffers_idx_offsets.get_unchecked(index); - view.buffer_idx += buffer_idx; + let buffer = local_buffers.get_unchecked_release(view.buffer_idx as usize); + let key = BufferKey { inner: buffer }; + let idx = self.buffers.get_full(&key).unwrap_unchecked_release().0; + + view.buffer_idx = idx as u32; } view })); @@ -111,6 +140,7 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { #[inline] /// Ignores the buffers and doesn't update the view. This is only correct in a filter. + /// /// # Safety /// doesn't check bounds pub unsafe fn extend_unchecked_no_buffers(&mut self, index: usize, start: usize, len: usize) { @@ -163,7 +193,12 @@ impl<'a, T: ViewType + ?Sized> From> for BinaryVi BinaryViewArrayGeneric::::new_unchecked( val.data_type, val.views.into(), - Arc::from(val.buffers), + Arc::from( + val.buffers + .into_iter() + .map(|buf| buf.inner.clone()) + .collect::>(), + ), val.validity.map(|v| v.into()), val.total_bytes_len, val.total_buffer_len, diff --git a/crates/polars-arrow/src/array/growable/boolean.rs b/crates/polars-arrow/src/array/growable/boolean.rs index e293d0051ca8..ea18791a804d 100644 --- a/crates/polars-arrow/src/array/growable/boolean.rs +++ b/crates/polars-arrow/src/array/growable/boolean.rs @@ -57,7 +57,7 @@ impl<'a> Growable<'a> for GrowableBoolean<'a> { let values = array.values(); let (slice, offset, _) = values.as_slice(); - // safety: invariant offset + length <= slice.len() + // SAFETY: invariant offset + length <= slice.len() unsafe { self.values .extend_from_slice_unchecked(slice, start + offset, len); diff --git a/crates/polars-arrow/src/array/growable/dictionary.rs b/crates/polars-arrow/src/array/growable/dictionary.rs index 3c08b1cd65d9..dd2dbc01fde4 100644 --- a/crates/polars-arrow/src/array/growable/dictionary.rs +++ b/crates/polars-arrow/src/array/growable/dictionary.rs @@ -82,7 +82,7 @@ impl<'a, T: DictionaryKey> GrowableDictionary<'a, T> { validity.map(|v| v.into()), ); - // Safety - the invariant of this struct ensures that this is up-held + // SAFETY: the invariant of this struct ensures that this is up-held unsafe { DictionaryArray::::try_new_unchecked( self.data_type.clone(), diff --git a/crates/polars-arrow/src/array/growable/mod.rs b/crates/polars-arrow/src/array/growable/mod.rs index aea9cdd8789e..ca8fc87a5a86 100644 --- a/crates/polars-arrow/src/array/growable/mod.rs +++ b/crates/polars-arrow/src/array/growable/mod.rs @@ -1,8 +1,6 @@ //! Contains the trait [`Growable`] and corresponding concreate implementations, one per concrete array, //! that offer the ability to create a new [`Array`] out of slices of existing [`Array`]s. -use std::sync::Arc; - use crate::array::*; use crate::datatypes::*; @@ -37,11 +35,13 @@ mod utils; pub trait Growable<'a> { /// Extends this [`Growable`] with elements from the bounded [`Array`] at index `index` from /// a slice starting at `start` and length `len`. + /// /// # Safety /// Doesn't do any bound checks unsafe fn extend(&mut self, index: usize, start: usize, len: usize); /// Extends this [`Growable`] with null elements, disregarding the bound arrays + /// /// # Safety /// Doesn't do any bound checks fn extend_validity(&mut self, additional: usize); diff --git a/crates/polars-arrow/src/array/growable/utils.rs b/crates/polars-arrow/src/array/growable/utils.rs index 7357f661b199..7cb4b667a5c1 100644 --- a/crates/polars-arrow/src/array/growable/utils.rs +++ b/crates/polars-arrow/src/array/growable/utils.rs @@ -38,7 +38,7 @@ pub(super) fn extend_validity( Some(validity) => { debug_assert!(start + len <= validity.len()); let (slice, offset, _) = validity.as_slice(); - // safety: invariant offset + length <= slice.len() + // SAFETY: invariant offset + length <= slice.len() unsafe { mutable_validity.extend_from_slice_unchecked(slice, start + offset, len); } diff --git a/crates/polars-arrow/src/array/indexable.rs b/crates/polars-arrow/src/array/indexable.rs index d3f466722aa6..b4f455ab00c4 100644 --- a/crates/polars-arrow/src/array/indexable.rs +++ b/crates/polars-arrow/src/array/indexable.rs @@ -22,6 +22,7 @@ pub trait Indexable { fn value_at(&self, index: usize) -> Self::Value<'_>; /// Returns the element at index `i`. + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] diff --git a/crates/polars-arrow/src/array/list/data.rs b/crates/polars-arrow/src/array/list/data.rs index 6f3424c96ce6..212778a05abb 100644 --- a/crates/polars-arrow/src/array/list/data.rs +++ b/crates/polars-arrow/src/array/list/data.rs @@ -14,7 +14,7 @@ impl Arrow2Arrow for ListArray { .nulls(self.validity.as_ref().map(|b| b.clone().into())) .child_data(vec![to_data(self.values.as_ref())]); - // Safety: Array is valid + // SAFETY: Array is valid unsafe { builder.build_unchecked() } } diff --git a/crates/polars-arrow/src/array/list/mod.rs b/crates/polars-arrow/src/array/list/mod.rs index d53747e7f228..6c2934fa1061 100644 --- a/crates/polars-arrow/src/array/list/mod.rs +++ b/crates/polars-arrow/src/array/list/mod.rs @@ -114,6 +114,7 @@ impl ListArray { } /// Slices this [`ListArray`]. + /// /// # Safety /// The caller must ensure that `offset + length < self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -144,20 +145,21 @@ impl ListArray { #[inline] pub fn value(&self, i: usize) -> Box { assert!(i < self.len()); - // Safety: invariant of this function + // SAFETY: invariant of this function unsafe { self.value_unchecked(i) } } /// Returns the element at index `i` as &str + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] pub unsafe fn value_unchecked(&self, i: usize) -> Box { - // safety: the invariant of the function + // SAFETY: the invariant of the function let (start, end) = self.offsets.start_end_unchecked(i); let length = end - start; - // safety: the invariant of the struct + // SAFETY: the invariant of the struct self.values.sliced_unchecked(start, length) } diff --git a/crates/polars-arrow/src/array/map/data.rs b/crates/polars-arrow/src/array/map/data.rs index cb8862a4df3d..8eb586e05f4c 100644 --- a/crates/polars-arrow/src/array/map/data.rs +++ b/crates/polars-arrow/src/array/map/data.rs @@ -14,7 +14,7 @@ impl Arrow2Arrow for MapArray { .nulls(self.validity.as_ref().map(|b| b.clone().into())) .child_data(vec![to_data(self.field.as_ref())]); - // Safety: Array is valid + // SAFETY: Array is valid unsafe { builder.build_unchecked() } } diff --git a/crates/polars-arrow/src/array/map/iterator.rs b/crates/polars-arrow/src/array/map/iterator.rs index f424e91b8043..558405ddc8de 100644 --- a/crates/polars-arrow/src/array/map/iterator.rs +++ b/crates/polars-arrow/src/array/map/iterator.rs @@ -32,7 +32,7 @@ impl<'a> Iterator for MapValuesIter<'a> { } let old = self.index; self.index += 1; - // Safety: + // SAFETY: // self.end is maximized by the length of the array Some(unsafe { self.array.value_unchecked(old) }) } @@ -52,7 +52,7 @@ impl<'a> DoubleEndedIterator for MapValuesIter<'a> { None } else { self.end -= 1; - // Safety: + // SAFETY: // self.end is maximized by the length of the array Some(unsafe { self.array.value_unchecked(self.end) }) } diff --git a/crates/polars-arrow/src/array/map/mod.rs b/crates/polars-arrow/src/array/map/mod.rs index d057192ef612..c6ebfc353a06 100644 --- a/crates/polars-arrow/src/array/map/mod.rs +++ b/crates/polars-arrow/src/array/map/mod.rs @@ -111,6 +111,7 @@ impl MapArray { } /// Returns a slice of this [`MapArray`]. + /// /// # Safety /// The caller must ensure that `offset + length < self.len()`. #[inline] @@ -168,6 +169,7 @@ impl MapArray { } /// Returns the element at index `i`. + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] diff --git a/crates/polars-arrow/src/array/mod.rs b/crates/polars-arrow/src/array/mod.rs index 24845b33bc89..b2a109f5af39 100644 --- a/crates/polars-arrow/src/array/mod.rs +++ b/crates/polars-arrow/src/array/mod.rs @@ -75,6 +75,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { } /// Returns whether slot `i` is null. + /// /// # Safety /// The caller must ensure `i < self.len()` #[inline] @@ -103,6 +104,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { /// Slices the [`Array`]. /// # Implementation /// This operation is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()` unsafe fn slice_unchecked(&mut self, offset: usize, length: usize); @@ -123,6 +125,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { /// # Implementation /// This operation is `O(1)` over `len`, as it amounts to increase two ref counts /// and moving the struct to the heap. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()` #[must_use] @@ -143,15 +146,6 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { dyn_clone::clone_trait_object!(Array); -/// A trait describing an array with a backing store that can be preallocated to -/// a given size. -pub(crate) trait Container { - /// Create this array with a given capacity. - fn with_capacity(capacity: usize) -> Self - where - Self: Sized; -} - /// A trait describing a mutable array; i.e. an array whose values can be changed. /// Mutable arrays cannot be cloned but can be mutated in place, /// thereby making them useful to perform numeric operations without allocations. @@ -496,6 +490,7 @@ macro_rules! impl_sliced { /// Returns this array sliced. /// # Implementation /// This function is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. #[inline] @@ -530,6 +525,12 @@ macro_rules! impl_mut_validity { } self.validity = validity; } + + /// Takes the validity of this array, leaving it without a validity mask. + #[inline] + pub fn take_validity(&mut self) -> Option { + self.validity.take() + } } } @@ -740,6 +741,7 @@ pub trait TryPush { /// A trait describing the ability of a struct to receive new items. pub trait PushUnchecked { /// Push a new element that holds the invariants of the struct. + /// /// # Safety /// The items must uphold the invariants of the struct /// Read the specific implementation of the trait to understand what these are. diff --git a/crates/polars-arrow/src/array/null.rs b/crates/polars-arrow/src/array/null.rs index 82269e3c2066..753700abba16 100644 --- a/crates/polars-arrow/src/array/null.rs +++ b/crates/polars-arrow/src/array/null.rs @@ -62,6 +62,7 @@ impl NullArray { } /// Returns a slice of the [`NullArray`]. + /// /// # Safety /// The caller must ensure that `offset + length < self.len()`. pub unsafe fn slice_unchecked(&mut self, _offset: usize, length: usize) { @@ -187,7 +188,7 @@ mod arrow { pub fn to_data(&self) -> ArrayData { let builder = ArrayDataBuilder::new(arrow_schema::DataType::Null).len(self.len()); - // Safety: safe by construction + // SAFETY: safe by construction unsafe { builder.build_unchecked() } } diff --git a/crates/polars-arrow/src/array/physical_binary.rs b/crates/polars-arrow/src/array/physical_binary.rs index 36e4ecf52d35..d92085d11d25 100644 --- a/crates/polars-arrow/src/array/physical_binary.rs +++ b/crates/polars-arrow/src/array/physical_binary.rs @@ -219,7 +219,7 @@ pub(crate) fn extend_validity( if let Some(other) = other { if let Some(validity) = validity { let slice = other.as_slice(); - // safety: invariant offset + length <= slice.len() + // SAFETY: invariant offset + length <= slice.len() unsafe { validity.extend_from_slice_unchecked(slice, 0, other.len()) } } else { let mut new_validity = MutableBitmap::from_len_set(length); diff --git a/crates/polars-arrow/src/array/primitive/data.rs b/crates/polars-arrow/src/array/primitive/data.rs index d4879f796812..1a32b230f54f 100644 --- a/crates/polars-arrow/src/array/primitive/data.rs +++ b/crates/polars-arrow/src/array/primitive/data.rs @@ -14,7 +14,7 @@ impl Arrow2Arrow for PrimitiveArray { .buffers(vec![self.values.clone().into()]) .nulls(self.validity.as_ref().map(|b| b.clone().into())); - // Safety: Array is valid + // SAFETY: Array is valid unsafe { builder.build_unchecked() } } diff --git a/crates/polars-arrow/src/array/primitive/from_natural.rs b/crates/polars-arrow/src/array/primitive/from_natural.rs index 0530c748af7e..a70259a8eeff 100644 --- a/crates/polars-arrow/src/array/primitive/from_natural.rs +++ b/crates/polars-arrow/src/array/primitive/from_natural.rs @@ -1,5 +1,3 @@ -use std::iter::FromIterator; - use super::{MutablePrimitiveArray, PrimitiveArray}; use crate::types::NativeType; diff --git a/crates/polars-arrow/src/array/primitive/mod.rs b/crates/polars-arrow/src/array/primitive/mod.rs index c85cbc8420b3..8f70d233baeb 100644 --- a/crates/polars-arrow/src/array/primitive/mod.rs +++ b/crates/polars-arrow/src/array/primitive/mod.rs @@ -208,6 +208,7 @@ impl PrimitiveArray { /// Returns the value at index `i`. /// The value on null slots is undetermined (it can be anything). + /// /// # Safety /// Caller must be sure that `i < self.len()` #[inline] @@ -243,6 +244,7 @@ impl PrimitiveArray { /// Slices this [`PrimitiveArray`] by an offset and length. /// # Implementation /// This operation is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. #[inline] @@ -341,7 +343,7 @@ impl PrimitiveArray { /// This function returns a [`MutablePrimitiveArray`] (via [`std::sync::Arc::get_mut`]) iff both values /// and validity have not been cloned / are unique references to their underlying vectors. /// - /// This function is primarily used to re-use memory regions. + /// This function is primarily used to reuse memory regions. #[must_use] pub fn into_mut(self) -> Either> { use Either::*; @@ -420,6 +422,7 @@ impl PrimitiveArray { } /// Creates a new [`PrimitiveArray`] from an iterator over values + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -433,6 +436,7 @@ impl PrimitiveArray { } /// Creates a [`PrimitiveArray`] from an iterator of optional values. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -448,6 +452,37 @@ impl PrimitiveArray { pub fn new(data_type: ArrowDataType, values: Buffer, validity: Option) -> Self { Self::try_new(data_type, values, validity).unwrap() } + + /// Transmute this PrimitiveArray into another PrimitiveArray. + /// + /// T and U must have the same size and alignment. + pub fn transmute(self) -> PrimitiveArray { + let PrimitiveArray { + values, validity, .. + } = self; + + // SAFETY: this is fine, we checked size and alignment, and NativeType + // is always Pod. + assert_eq!(std::mem::size_of::(), std::mem::size_of::()); + assert_eq!(std::mem::align_of::(), std::mem::align_of::()); + let new_values = unsafe { std::mem::transmute::, Buffer>(values) }; + PrimitiveArray::new(U::PRIMITIVE.into(), new_values, validity) + } + + /// Fills this entire array with the given value, leaving the validity mask intact. + /// + /// Reuses the memory of the PrimitiveArray if possible. + pub fn fill_with(mut self, value: T) -> Self { + if let Some(values) = self.get_mut_values() { + for x in values.iter_mut() { + *x = value; + } + self + } else { + let values = vec![value; self.len()]; + Self::new(T::PRIMITIVE.into(), values.into(), self.validity) + } + } } impl Array for PrimitiveArray { diff --git a/crates/polars-arrow/src/array/primitive/mutable.rs b/crates/polars-arrow/src/array/primitive/mutable.rs index 3c7a8489b77e..eab3498d8ed9 100644 --- a/crates/polars-arrow/src/array/primitive/mutable.rs +++ b/crates/polars-arrow/src/array/primitive/mutable.rs @@ -1,8 +1,6 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::PolarsResult; -use polars_utils::total_ord::TotalOrdWrap; use super::{check, PrimitiveArray}; use crate::array::physical_binary::extend_validity; @@ -195,6 +193,7 @@ impl MutablePrimitiveArray { } /// Extends the [`MutablePrimitiveArray`] from an iterator of trusted len. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -224,6 +223,7 @@ impl MutablePrimitiveArray { /// Extends the [`MutablePrimitiveArray`] from an iterator of values of trusted len. /// This differs from `extend_trusted_len_unchecked` which accepts in iterator of optional values. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -312,7 +312,7 @@ impl MutablePrimitiveArray { /// Panics iff index is larger than `self.len()`. pub fn set(&mut self, index: usize, value: Option) { assert!(index < self.len()); - // Safety: + // SAFETY: // we just checked bounds unsafe { self.set_unchecked(index, value) } } @@ -320,6 +320,7 @@ impl MutablePrimitiveArray { /// Sets position `index` to `value`. /// Note that if it is the first time a null appears in this array, /// this initializes the validity bitmap (`O(N)`). + /// /// # Safety /// Caller must ensure `index < self.len()` pub unsafe fn set_unchecked(&mut self, index: usize, value: Option) { @@ -364,14 +365,6 @@ impl Extend> for MutablePrimitiveArray { } } -impl Extend>> for MutablePrimitiveArray { - fn extend>>>(&mut self, iter: I) { - let iter = iter.into_iter(); - self.reserve(iter.size_hint().0); - iter.for_each(|x| self.push(x.map(|x| x.0))) - } -} - impl TryExtend> for MutablePrimitiveArray { /// This is infallible and is implemented for consistency with all other types fn try_extend>>(&mut self, iter: I) -> PolarsResult<()> { @@ -448,6 +441,7 @@ impl MutablePrimitiveArray { } /// Creates a [`MutablePrimitiveArray`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. `size_hint().1` correctly reports its length. @@ -477,6 +471,7 @@ impl MutablePrimitiveArray { } /// Creates a [`MutablePrimitiveArray`] from an fallible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -525,6 +520,7 @@ impl MutablePrimitiveArray { } /// Creates a new [`MutablePrimitiveArray`] from an iterator over values + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/specification.rs b/crates/polars-arrow/src/array/specification.rs index 72fb80ae2202..e5759144bf44 100644 --- a/crates/polars-arrow/src/array/specification.rs +++ b/crates/polars-arrow/src/array/specification.rs @@ -93,7 +93,7 @@ pub fn try_check_utf8(offsets: &[O], values: &[u8]) -> PolarsResult<( for start in starts { let start = start.to_usize(); - // Safety: `try_check_offsets_bounds` just checked for bounds + // SAFETY: `try_check_offsets_bounds` just checked for bounds let b = *unsafe { values.get_unchecked(start) }; // A valid code-point iff it does not start with 0b10xxxxxx diff --git a/crates/polars-arrow/src/array/static_array.rs b/crates/polars-arrow/src/array/static_array.rs index 4ae0f44212ba..ac8fbc4cec32 100644 --- a/crates/polars-arrow/src/array/static_array.rs +++ b/crates/polars-arrow/src/array/static_array.rs @@ -25,7 +25,7 @@ pub trait StaticArray: type ZeroableValueT<'a>: Zeroable + From> where Self: 'a; - type ValueIterT<'a>: Iterator> + TrustedLen + type ValueIterT<'a>: DoubleEndedIterator> + TrustedLen + Send + Sync where Self: 'a; diff --git a/crates/polars-arrow/src/array/static_array_collect.rs b/crates/polars-arrow/src/array/static_array_collect.rs index 27f86fec1f5b..9413b0a16778 100644 --- a/crates/polars-arrow/src/array/static_array_collect.rs +++ b/crates/polars-arrow/src/array/static_array_collect.rs @@ -552,7 +552,7 @@ impl ArrayFromIter> for BinaryViewArray { // fn try_arr_from_iter_trusted(iter: I) -> Result } -/// We use this to re-use the binary collect implementation for strings. +/// We use this to reuse the binary collect implementation for strings. /// # Safety /// The array must be valid UTF-8. unsafe fn into_utf8array(arr: BinaryArray) -> Utf8Array { @@ -807,12 +807,14 @@ impl ArrayFromIter> for BooleanArray { // as Rust considers that AsRef for Option<&dyn Array> could be implemented. trait AsArray { fn as_array(&self) -> &dyn Array; + #[cfg(feature = "dtype-array")] fn into_boxed_array(self) -> Box; // Prevents unnecessary re-boxing. } impl AsArray for Box { fn as_array(&self) -> &dyn Array { self.as_ref() } + #[cfg(feature = "dtype-array")] fn into_boxed_array(self) -> Box { self } @@ -821,6 +823,7 @@ impl<'a> AsArray for &'a dyn Array { fn as_array(&self) -> &'a dyn Array { *self } + #[cfg(feature = "dtype-array")] fn into_boxed_array(self) -> Box { self.to_boxed() } diff --git a/crates/polars-arrow/src/array/struct_/data.rs b/crates/polars-arrow/src/array/struct_/data.rs index b96dc4ffe28b..4dfcb0010a73 100644 --- a/crates/polars-arrow/src/array/struct_/data.rs +++ b/crates/polars-arrow/src/array/struct_/data.rs @@ -12,7 +12,7 @@ impl Arrow2Arrow for StructArray { .nulls(self.validity.as_ref().map(|b| b.clone().into())) .child_data(self.values.iter().map(|x| to_data(x.as_ref())).collect()); - // Safety: Array is valid + // SAFETY: Array is valid unsafe { builder.build_unchecked() } } diff --git a/crates/polars-arrow/src/array/struct_/iterator.rs b/crates/polars-arrow/src/array/struct_/iterator.rs index cb8e6aafbb09..4e89af3a6a7f 100644 --- a/crates/polars-arrow/src/array/struct_/iterator.rs +++ b/crates/polars-arrow/src/array/struct_/iterator.rs @@ -31,7 +31,7 @@ impl<'a> Iterator for StructValueIter<'a> { let old = self.index; self.index += 1; - // Safety: + // SAFETY: // self.end is maximized by the length of the array Some( self.array @@ -58,7 +58,7 @@ impl<'a> DoubleEndedIterator for StructValueIter<'a> { } else { self.end -= 1; - // Safety: + // SAFETY: // self.end is maximized by the length of the array Some( self.array diff --git a/crates/polars-arrow/src/array/struct_/mod.rs b/crates/polars-arrow/src/array/struct_/mod.rs index 6f796ac18ac4..21d3247bbc85 100644 --- a/crates/polars-arrow/src/array/struct_/mod.rs +++ b/crates/polars-arrow/src/array/struct_/mod.rs @@ -178,6 +178,7 @@ impl StructArray { /// Slices this [`StructArray`]. /// # Implementation /// This operation is `O(F)` where `F` is the number of fields. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { diff --git a/crates/polars-arrow/src/array/union/data.rs b/crates/polars-arrow/src/array/union/data.rs index 8bbac00b6c6e..4303ab7b4356 100644 --- a/crates/polars-arrow/src/array/union/data.rs +++ b/crates/polars-arrow/src/array/union/data.rs @@ -25,7 +25,7 @@ impl Arrow2Arrow for UnionArray { ), }; - // Safety: Array is valid + // SAFETY: Array is valid unsafe { builder.build_unchecked() } } diff --git a/crates/polars-arrow/src/array/union/mod.rs b/crates/polars-arrow/src/array/union/mod.rs index 703a5f8d86e3..86d7cbed7397 100644 --- a/crates/polars-arrow/src/array/union/mod.rs +++ b/crates/polars-arrow/src/array/union/mod.rs @@ -130,7 +130,7 @@ impl UnionArray { Some(hash) } else { - // Safety: every type in types is smaller than number of fields + // SAFETY: every type in types is smaller than number of fields let mut is_valid = true; for &type_ in types.iter() { if type_ < 0 || type_ >= number_of_fields { @@ -240,6 +240,7 @@ impl UnionArray { /// Returns a slice of this [`UnionArray`]. /// # Implementation /// This operation is `O(F)` where `F` is the number of fields. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. #[inline] @@ -296,20 +297,21 @@ impl UnionArray { /// Returns the index and slot of the field to select from `self.fields`. /// The first value is guaranteed to be `< self.fields().len()` + /// /// # Safety /// This function is safe iff `index < self.len`. #[inline] pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) { debug_assert!(index < self.len()); - // Safety: assumption of the function + // SAFETY: assumption of the function let type_ = unsafe { *self.types.get_unchecked(index) }; - // Safety: assumption of the struct + // SAFETY: assumption of the struct let type_ = self .map .as_ref() .map(|map| unsafe { *map.get_unchecked(type_ as usize) }) .unwrap_or(type_ as usize); - // Safety: assumption of the function + // SAFETY: assumption of the function let index = self.field_slot_unchecked(index); (type_, index) } @@ -323,12 +325,13 @@ impl UnionArray { } /// Returns the slot `index` as a [`Scalar`]. + /// /// # Safety /// This function is safe iff `i < self.len`. pub unsafe fn value_unchecked(&self, index: usize) -> Box { debug_assert!(index < self.len()); let (type_, index) = self.index_unchecked(index); - // Safety: assumption of the struct + // SAFETY: assumption of the struct debug_assert!(type_ < self.fields.len()); let field = self.fields.get_unchecked(type_).as_ref(); new_scalar(field, index) diff --git a/crates/polars-arrow/src/array/utf8/data.rs b/crates/polars-arrow/src/array/utf8/data.rs index 16674c969372..577a43677c05 100644 --- a/crates/polars-arrow/src/array/utf8/data.rs +++ b/crates/polars-arrow/src/array/utf8/data.rs @@ -15,7 +15,7 @@ impl Arrow2Arrow for Utf8Array { ]) .nulls(self.validity.as_ref().map(|b| b.clone().into())); - // Safety: Array is valid + // SAFETY: Array is valid unsafe { builder.build_unchecked() } } @@ -28,7 +28,7 @@ impl Arrow2Arrow for Utf8Array { let buffers = data.buffers(); - // Safety: ArrayData is valid + // SAFETY: ArrayData is valid let mut offsets = unsafe { OffsetsBuffer::new_unchecked(buffers[0].clone().into()) }; offsets.slice(data.offset(), data.len() + 1); diff --git a/crates/polars-arrow/src/array/utf8/from.rs b/crates/polars-arrow/src/array/utf8/from.rs index c1dcaf09b10d..6f90bac99495 100644 --- a/crates/polars-arrow/src/array/utf8/from.rs +++ b/crates/polars-arrow/src/array/utf8/from.rs @@ -1,5 +1,3 @@ -use std::iter::FromIterator; - use super::{MutableUtf8Array, Utf8Array}; use crate::offset::Offset; diff --git a/crates/polars-arrow/src/array/utf8/mod.rs b/crates/polars-arrow/src/array/utf8/mod.rs index c25234c3cb2a..218e71323abf 100644 --- a/crates/polars-arrow/src/array/utf8/mod.rs +++ b/crates/polars-arrow/src/array/utf8/mod.rs @@ -156,6 +156,7 @@ impl Utf8Array { } /// Returns the value of the element at index `i`, ignoring the array's validity. + /// /// # Safety /// This function is safe iff `i < self.len`. #[inline] @@ -223,6 +224,7 @@ impl Utf8Array { /// Slices this [`Utf8Array`]. /// # Implementation /// This function is `O(1)` + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -256,7 +258,7 @@ impl Utf8Array { use Either::*; if let Some(bitmap) = self.validity { match bitmap.into_mut() { - // Safety: invariants are preserved + // SAFETY: invariants are preserved Left(bitmap) => Left(unsafe { Utf8Array::new_unchecked( self.data_type, @@ -267,7 +269,7 @@ impl Utf8Array { }), Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { (Left(values), Left(offsets)) => { - // Safety: invariants are preserved + // SAFETY: invariants are preserved Left(unsafe { Utf8Array::new_unchecked( self.data_type, @@ -278,7 +280,7 @@ impl Utf8Array { }) }, (Left(values), Right(offsets)) => { - // Safety: invariants are preserved + // SAFETY: invariants are preserved Left(unsafe { Utf8Array::new_unchecked( self.data_type, @@ -289,7 +291,7 @@ impl Utf8Array { }) }, (Right(values), Left(offsets)) => { - // Safety: invariants are preserved + // SAFETY: invariants are preserved Left(unsafe { Utf8Array::new_unchecked( self.data_type, @@ -362,6 +364,7 @@ impl Utf8Array { /// * The last offset is not equal to the values' length. /// * the validity's length is not equal to `offsets.len()`. /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// /// # Safety /// This function is unsound iff: /// * The `values` between two consecutive `offsets` are not valid utf8 @@ -430,6 +433,7 @@ impl Utf8Array { } /// Creates a [`Utf8Array`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -453,6 +457,7 @@ impl Utf8Array { } /// Creates a [`Utf8Array`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/utf8/mutable.rs b/crates/polars-arrow/src/array/utf8/mutable.rs index a67d2e00b4c2..ef9a5e8527b7 100644 --- a/crates/polars-arrow/src/array/utf8/mutable.rs +++ b/crates/polars-arrow/src/array/utf8/mutable.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -75,6 +74,7 @@ impl MutableUtf8Array { } /// Create a [`MutableUtf8Array`] out of low-end APIs. + /// /// # Safety /// The caller must ensure that every value between offsets is a valid utf8. /// # Panics @@ -145,14 +145,13 @@ impl MutableUtf8Array { } /// Returns the value of the element at index `i`, ignoring the array's validity. - /// # Safety - /// This function is safe iff `i < self.len`. #[inline] pub fn value(&self, i: usize) -> &str { self.values.value(i) } /// Returns the value of the element at index `i`, ignoring the array's validity. + /// /// # Safety /// This function is safe iff `i < self.len`. #[inline] @@ -329,6 +328,7 @@ impl MutableUtf8Array { /// Extends the [`MutableUtf8Array`] from an iterator of values of trusted len. /// This differs from `extended_trusted_len_unchecked` which accepts iterator of optional /// values. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -357,6 +357,7 @@ impl MutableUtf8Array { } /// Extends [`MutableUtf8Array`] from an iterator of trusted len. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -376,6 +377,7 @@ impl MutableUtf8Array { } /// Creates a [`MutableUtf8Array`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -404,6 +406,7 @@ impl MutableUtf8Array { } /// Creates a [`MutableUtf8Array`] from an iterator of trusted length of `&str`. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -440,6 +443,7 @@ impl MutableUtf8Array { } /// Creates a [`MutableUtf8Array`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/utf8/mutable_values.rs b/crates/polars-arrow/src/array/utf8/mutable_values.rs index 92fd10661785..ce3c2f71f20c 100644 --- a/crates/polars-arrow/src/array/utf8/mutable_values.rs +++ b/crates/polars-arrow/src/array/utf8/mutable_values.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -23,7 +22,7 @@ pub struct MutableUtf8ValuesArray { impl From> for Utf8Array { fn from(other: MutableUtf8ValuesArray) -> Self { - // Safety: + // SAFETY: // `MutableUtf8ValuesArray` has the same invariants as `Utf8Array` and thus // `Utf8Array` can be safely created from `MutableUtf8ValuesArray` without checks. unsafe { @@ -39,7 +38,7 @@ impl From> for Utf8Array { impl From> for MutableUtf8Array { fn from(other: MutableUtf8ValuesArray) -> Self { - // Safety: + // SAFETY: // `MutableUtf8ValuesArray` has the same invariants as `MutableUtf8Array` unsafe { MutableUtf8Array::::new_unchecked(other.data_type, other.offsets, other.values, None) @@ -95,6 +94,7 @@ impl MutableUtf8ValuesArray { /// This function does not panic iff: /// * The last offset is equal to the values' length. /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is equal to either `Utf8` or `LargeUtf8`. + /// /// # Safety /// This function is safe iff: /// * the offsets are monotonically increasing @@ -187,7 +187,7 @@ impl MutableUtf8ValuesArray { self.offsets.pop()?; let start = self.offsets.last().to_usize(); let value = self.values.split_off(start); - // Safety: utf8 is validated on initialization + // SAFETY: utf8 is validated on initialization Some(unsafe { String::from_utf8_unchecked(value) }) } @@ -201,6 +201,7 @@ impl MutableUtf8ValuesArray { } /// Returns the value of the element at index `i`. + /// /// # Safety /// This function is safe iff `i < self.len`. #[inline] @@ -309,6 +310,7 @@ impl MutableUtf8ValuesArray { } /// Extends [`MutableUtf8ValuesArray`] from an iterator of trusted len. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -333,6 +335,7 @@ impl MutableUtf8ValuesArray { } /// Returns a new [`MutableUtf8ValuesArray`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/bitmap/immutable.rs b/crates/polars-arrow/src/bitmap/immutable.rs index 53b5a71bc1b5..1e4615018fa4 100644 --- a/crates/polars-arrow/src/bitmap/immutable.rs +++ b/crates/polars-arrow/src/bitmap/immutable.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::ops::Deref; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; @@ -190,6 +189,7 @@ impl Bitmap { } /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// /// # Safety /// The caller must ensure that `self.offset + offset + length <= self.len()` #[inline] @@ -246,6 +246,7 @@ impl Bitmap { } /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// /// # Safety /// The caller must ensure that `self.offset + offset + length <= self.len()` #[inline] @@ -264,6 +265,7 @@ impl Bitmap { } /// Unsafely returns whether the bit at position `i` is set. + /// /// # Safety /// Unsound iff `i >= self.len()`. #[inline] @@ -418,6 +420,7 @@ impl FromIterator for Bitmap { impl Bitmap { /// Creates a new [`Bitmap`] from an iterator of booleans. + /// /// # Safety /// The iterator must report an accurate length. #[inline] @@ -440,6 +443,7 @@ impl Bitmap { } /// Creates a new [`Bitmap`] from a fallible iterator of booleans. + /// /// # Safety /// The iterator must report an accurate length. #[inline] @@ -493,7 +497,7 @@ impl From for arrow_buffer::buffer::NullBuffer { let null_count = value.unset_bits(); let buffer = crate::buffer::to_buffer(value.bytes); let buffer = arrow_buffer::buffer::BooleanBuffer::new(buffer, value.offset, value.length); - // Safety: null count is accurate + // SAFETY: null count is accurate unsafe { arrow_buffer::buffer::NullBuffer::new_unchecked(buffer, null_count) } } } diff --git a/crates/polars-arrow/src/bitmap/mutable.rs b/crates/polars-arrow/src/bitmap/mutable.rs index de6de7d42cbf..cd3eaa99b389 100644 --- a/crates/polars-arrow/src/bitmap/mutable.rs +++ b/crates/polars-arrow/src/bitmap/mutable.rs @@ -1,5 +1,4 @@ use std::hint::unreachable_unchecked; -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -221,6 +220,7 @@ impl MutableBitmap { } /// Pushes a new bit to the [`MutableBitmap`] + /// /// # Safety /// The caller must ensure that the [`MutableBitmap`] has sufficient capacity. #[inline] @@ -318,6 +318,7 @@ impl MutableBitmap { } /// Sets the position `index` to `value` + /// /// # Safety /// Caller must ensure that `index < self.len()` #[inline] @@ -524,11 +525,12 @@ impl MutableBitmap { /// Extends `self` from a [`TrustedLen`] iterator. #[inline] pub fn extend_from_trusted_len_iter>(&mut self, iterator: I) { - // safety: I: TrustedLen + // SAFETY: I: TrustedLen unsafe { self.extend_from_trusted_len_iter_unchecked(iterator) } } /// Extends `self` from an iterator of trusted len. + /// /// # Safety /// The caller must guarantee that the iterator has a trusted len. #[inline] @@ -577,6 +579,7 @@ impl MutableBitmap { } /// Creates a new [`MutableBitmap`] from an iterator of booleans. + /// /// # Safety /// The iterator must report an accurate length. #[inline] @@ -597,7 +600,7 @@ impl MutableBitmap { where I: TrustedLen, { - // Safety: Iterator is `TrustedLen` + // SAFETY: Iterator is `TrustedLen` unsafe { Self::from_trusted_len_iter_unchecked(iterator) } } @@ -610,6 +613,7 @@ impl MutableBitmap { } /// Creates a new [`MutableBitmap`] from an falible iterator of booleans. + /// /// # Safety /// The caller must guarantee that the iterator is `TrustedLen`. pub unsafe fn try_from_trusted_len_iter_unchecked( @@ -697,6 +701,7 @@ impl MutableBitmap { /// # Implementation /// When both [`MutableBitmap`]'s length and `offset` are both multiples of 8, /// this function performs a memcopy. Else, it first aligns bit by bit and then performs a memcopy. + /// /// # Safety /// Caller must ensure `offset + length <= slice.len() * 8` #[inline] @@ -729,7 +734,7 @@ impl MutableBitmap { #[inline] pub fn extend_from_slice(&mut self, slice: &[u8], offset: usize, length: usize) { assert!(offset + length <= slice.len() * 8); - // safety: invariant is asserted + // SAFETY: invariant is asserted unsafe { self.extend_from_slice_unchecked(slice, offset, length) } } @@ -737,7 +742,7 @@ impl MutableBitmap { #[inline] pub fn extend_from_bitmap(&mut self, bitmap: &Bitmap) { let (slice, offset, length) = bitmap.as_slice(); - // safety: bitmap.as_slice adheres to the invariant + // SAFETY: bitmap.as_slice adheres to the invariant unsafe { self.extend_from_slice_unchecked(slice, offset, length); } diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs index 4ab9d300ba02..7bc12e22898e 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::slice::ChunksExact; use super::{BitChunk, BitChunkIterExact}; diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs index 71f56a284274..8a1668a37d1f 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - mod chunks_exact; mod merge; diff --git a/crates/polars-arrow/src/buffer/immutable.rs b/crates/polars-arrow/src/buffer/immutable.rs index 15d7b0935edc..158a0ed4f61d 100644 --- a/crates/polars-arrow/src/buffer/immutable.rs +++ b/crates/polars-arrow/src/buffer/immutable.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::ops::Deref; use std::sync::Arc; use std::usize; @@ -113,18 +112,19 @@ impl Buffer { /// Returns the byte slice stored in this buffer #[inline] pub fn as_slice(&self) -> &[T] { - // Safety: + // SAFETY: // invariant of this struct `offset + length <= data.len()` debug_assert!(self.offset() + self.length <= self.storage.len()); unsafe { std::slice::from_raw_parts(self.ptr, self.length) } } /// Returns the byte slice stored in this buffer + /// /// # Safety /// `index` must be smaller than `len` #[inline] pub(super) unsafe fn get_unchecked(&self, index: usize) -> &T { - // Safety: + // SAFETY: // invariant of this function debug_assert!(index < self.length); unsafe { &*self.ptr.add(index) } @@ -140,7 +140,7 @@ impl Buffer { offset + length <= self.len(), "the offset of the new Buffer cannot exceed the existing length" ); - // Safety: we just checked bounds + // SAFETY: we just checked bounds unsafe { self.sliced_unchecked(offset, length) } } @@ -153,12 +153,13 @@ impl Buffer { offset + length <= self.len(), "the offset of the new Buffer cannot exceed the existing length" ); - // Safety: we just checked bounds + // SAFETY: we just checked bounds unsafe { self.slice_unchecked(offset, length) } } /// Returns a new [`Buffer`] that is a slice of this buffer starting at `offset`. /// Doing so allows the same memory region to be shared between buffers. + /// /// # Safety /// The caller must ensure `offset + length <= self.len()` #[inline] @@ -169,6 +170,7 @@ impl Buffer { } /// Slices this buffer starting at `offset`. + /// /// # Safety /// The caller must ensure `offset + length <= self.len()` #[inline] diff --git a/crates/polars-arrow/src/buffer/mod.rs b/crates/polars-arrow/src/buffer/mod.rs index 9a66c19c5942..46ef0af62d98 100644 --- a/crates/polars-arrow/src/buffer/mod.rs +++ b/crates/polars-arrow/src/buffer/mod.rs @@ -29,6 +29,7 @@ impl Bytes { /// Takes ownership of an allocated memory region. /// # Panics /// This function panics if and only if pointer is not null + /// /// # Safety /// This function is safe if and only if `ptr` is valid for `length` /// # Implementation @@ -78,7 +79,7 @@ pub(crate) fn to_buffer( // This should never panic as ForeignVec pointer must be non-null let ptr = std::ptr::NonNull::new(value.as_ptr() as _).unwrap(); let len = value.len() * std::mem::size_of::(); - // Safety: allocation is guaranteed to be valid for `len` bytes + // SAFETY: allocation is guaranteed to be valid for `len` bytes unsafe { arrow_buffer::Buffer::from_custom_allocation(ptr, len, value) } } @@ -94,7 +95,7 @@ pub(crate) fn to_bytes(value: arrow_buffer::Buffer) let owner = crate::buffer::BytesAllocator::Arrow(value); - // Safety: slice is valid for len elements of T + // SAFETY: slice is valid for len elements of T unsafe { Bytes::from_foreign(ptr, len, owner) } } diff --git a/crates/polars-arrow/src/compute/aggregate/memory.rs b/crates/polars-arrow/src/compute/aggregate/memory.rs index d78ed4d23f50..7fd68506421d 100644 --- a/crates/polars-arrow/src/compute/aggregate/memory.rs +++ b/crates/polars-arrow/src/compute/aggregate/memory.rs @@ -1,8 +1,8 @@ use crate::array::*; use crate::bitmap::Bitmap; use crate::datatypes::PhysicalType; +pub use crate::types::PrimitiveType; use crate::{match_integer_type, with_match_primitive_type_full}; - fn validity_size(validity: Option<&Bitmap>) -> usize { validity.as_ref().map(|b| b.as_slice().0.len()).unwrap_or(0) } @@ -48,6 +48,10 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { let array = array.as_any().downcast_ref::().unwrap(); array.values().as_slice().0.len() + validity_size(array.validity()) }, + Primitive(PrimitiveType::DaysMs) => { + let array = array.as_any().downcast_ref::().unwrap(); + array.values().len() * std::mem::size_of::() * 2 + validity_size(array.validity()) + }, Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { let array = array .as_any() diff --git a/crates/polars-arrow/src/compute/aggregate/simd/mod.rs b/crates/polars-arrow/src/compute/aggregate/simd/mod.rs index ea359fc592f5..010ba336fc37 100644 --- a/crates/polars-arrow/src/compute/aggregate/simd/mod.rs +++ b/crates/polars-arrow/src/compute/aggregate/simd/mod.rs @@ -40,7 +40,8 @@ macro_rules! simd_add { }; } -pub(super) use simd_add; +// #[cfg(not(feature = "simd"))] +// pub(super) use simd_add; simd_add!(i128x8, i128, 8, add); diff --git a/crates/polars-arrow/src/compute/aggregate/simd/native.rs b/crates/polars-arrow/src/compute/aggregate/simd/native.rs index 01382ecbbc8f..eb33878decbd 100644 --- a/crates/polars-arrow/src/compute/aggregate/simd/native.rs +++ b/crates/polars-arrow/src/compute/aggregate/simd/native.rs @@ -1,7 +1,6 @@ use std::ops::Add; use super::super::sum::Sum; -use super::simd_add; use crate::types::simd::*; simd_add!(u8x64, u8, 64, wrapping_add); diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/add.rs b/crates/polars-arrow/src/compute/arithmetics/basic/add.rs deleted file mode 100644 index ec941edc2381..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/basic/add.rs +++ /dev/null @@ -1,25 +0,0 @@ -//! Definition of basic add operations with primitive arrays -use std::ops::Add; - -use super::NativeArithmetics; -use crate::array::PrimitiveArray; -use crate::compute::arity::{binary, unary}; - -/// Adds two primitive arrays with the same type. -/// Panics if the sum of one pair of values overflows. -pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray -where - T: NativeArithmetics + Add, -{ - binary(lhs, rhs, lhs.data_type().clone(), |a, b| a + b) -} - -/// Adds a scalar T to a primitive array of type T. -/// Panics if the sum of the values overflows. -pub fn add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray -where - T: NativeArithmetics + Add, -{ - let rhs = *rhs; - unary(lhs, |a| a + rhs, lhs.data_type().clone()) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/div.rs b/crates/polars-arrow/src/compute/arithmetics/basic/div.rs deleted file mode 100644 index 9b5220b1b1ef..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/basic/div.rs +++ /dev/null @@ -1,140 +0,0 @@ -//! Definition of basic div operations with primitive arrays -use std::ops::Div; - -use num_traits::{CheckedDiv, NumCast}; -use strength_reduce::{ - StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8, -}; - -use super::NativeArithmetics; -use crate::array::{Array, PrimitiveArray}; -use crate::compute::arity::{binary, binary_checked, unary, unary_checked}; -use crate::compute::utils::check_same_len; -use crate::datatypes::PrimitiveType; - -/// Divides two primitive arrays with the same type. -/// Panics if the divisor is zero of one pair of values overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::div; -/// use polars_arrow::array::Int32Array; -/// -/// let a = Int32Array::from(&[Some(10), Some(1), Some(6)]); -/// let b = Int32Array::from(&[Some(5), None, Some(6)]); -/// let result = div(&a, &b); -/// let expected = Int32Array::from(&[Some(2), None, Some(1)]); -/// assert_eq!(result, expected) -/// ``` -pub fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray -where - T: NativeArithmetics + Div, -{ - if rhs.null_count() == 0 { - binary(lhs, rhs, lhs.data_type().clone(), |a, b| a / b) - } else { - check_same_len(lhs, rhs).unwrap(); - let values = lhs.iter().zip(rhs.iter()).map(|(l, r)| match (l, r) { - (Some(l), Some(r)) => Some(*l / *r), - _ => None, - }); - - PrimitiveArray::from_trusted_len_iter(values).to(lhs.data_type().clone()) - } -} - -/// Checked division of two primitive arrays. If the result from the division -/// overflows, the result for the operation will change the validity array -/// making this operation None -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_div; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(-100i8), Some(10i8)]); -/// let b = Int8Array::from(&[Some(100i8), Some(0i8)]); -/// let result = checked_div(&a, &b); -/// let expected = Int8Array::from(&[Some(-1i8), None]); -/// assert_eq!(result, expected); -/// ``` -pub fn checked_div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray -where - T: NativeArithmetics + CheckedDiv, -{ - let op = move |a: T, b: T| a.checked_div(&b); - binary_checked(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Divide a primitive array of type T by a scalar T. -/// Panics if the divisor is zero. -pub fn div_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray -where - T: NativeArithmetics + Div + NumCast, -{ - let rhs = *rhs; - match T::PRIMITIVE { - PrimitiveType::UInt64 => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.to_u64().unwrap(); - - let reduced_div = StrengthReducedU64::new(rhs); - let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); - (&r as &dyn Array) - .as_any() - .downcast_ref::>() - .unwrap() - .clone() - }, - PrimitiveType::UInt32 => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.to_u32().unwrap(); - - let reduced_div = StrengthReducedU32::new(rhs); - let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); - (&r as &dyn Array) - .as_any() - .downcast_ref::>() - .unwrap() - .clone() - }, - PrimitiveType::UInt16 => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.to_u16().unwrap(); - - let reduced_div = StrengthReducedU16::new(rhs); - - let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); - (&r as &dyn Array) - .as_any() - .downcast_ref::>() - .unwrap() - .clone() - }, - PrimitiveType::UInt8 => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.to_u8().unwrap(); - - let reduced_div = StrengthReducedU8::new(rhs); - let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); - (&r as &dyn Array) - .as_any() - .downcast_ref::>() - .unwrap() - .clone() - }, - _ => unary(lhs, |a| a / rhs, lhs.data_type().clone()), - } -} - -/// Checked division of a primitive array of type T by a scalar T. If the -/// divisor is zero then the validity array is changed to None. -pub fn checked_div_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray -where - T: NativeArithmetics + CheckedDiv, -{ - let rhs = *rhs; - let op = move |a: T| a.checked_div(&rhs); - - unary_checked(lhs, op, lhs.data_type().clone()) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs b/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs deleted file mode 100644 index 0b384f1767f7..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs +++ /dev/null @@ -1,99 +0,0 @@ -//! Contains arithmetic functions for [`PrimitiveArray`]s. -//! -//! Each operation has four variants, like the rest of Rust's ecosystem: -//! * usual, that [`panic!`]s on overflow -//! * `checked_*` that turns overflowings to `None` -//! * `overflowing_*` returning a [`Bitmap`](crate::bitmap::Bitmap) with items that overflow. -//! * `saturating_*` that saturates the result. -mod add; -pub use add::*; -mod div; -pub use div::*; -mod mul; -pub use mul::*; -mod rem; -pub use rem::*; -mod sub; -use std::ops::Neg; - -use num_traits::{CheckedNeg, WrappingNeg}; -pub use sub::*; - -use super::super::arity::{unary, unary_checked}; -use crate::array::PrimitiveArray; -use crate::types::NativeType; - -/// Trait describing a [`NativeType`] whose semantics of arithmetic in Arrow equals -/// the semantics in Rust. -/// A counter example is `i128`, that in arrow represents a decimal while in rust represents -/// a signed integer. -pub trait NativeArithmetics: NativeType {} -impl NativeArithmetics for u8 {} -impl NativeArithmetics for u16 {} -impl NativeArithmetics for u32 {} -impl NativeArithmetics for u64 {} -impl NativeArithmetics for i8 {} -impl NativeArithmetics for i16 {} -impl NativeArithmetics for i32 {} -impl NativeArithmetics for i64 {} -impl NativeArithmetics for i128 {} -impl NativeArithmetics for f32 {} -impl NativeArithmetics for f64 {} - -/// Negates values from array. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::negate; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([None, Some(6), None, Some(7)]); -/// let result = negate(&a); -/// let expected = PrimitiveArray::from([None, Some(-6), None, Some(-7)]); -/// assert_eq!(result, expected) -/// ``` -pub fn negate(array: &PrimitiveArray) -> PrimitiveArray -where - T: NativeType + Neg, -{ - unary(array, |a| -a, array.data_type().clone()) -} - -/// Checked negates values from array. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_negate; -/// use polars_arrow::array::{Array, PrimitiveArray}; -/// -/// let a = PrimitiveArray::from([None, Some(6), Some(i8::MIN), Some(7)]); -/// let result = checked_negate(&a); -/// let expected = PrimitiveArray::from([None, Some(-6), None, Some(-7)]); -/// assert_eq!(result, expected); -/// assert!(!result.is_valid(2)) -/// ``` -pub fn checked_negate(array: &PrimitiveArray) -> PrimitiveArray -where - T: NativeType + CheckedNeg, -{ - unary_checked(array, |a| a.checked_neg(), array.data_type().clone()) -} - -/// Wrapping negates values from array. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::wrapping_negate; -/// use polars_arrow::array::{Array, PrimitiveArray}; -/// -/// let a = PrimitiveArray::from([None, Some(6), Some(i8::MIN), Some(7)]); -/// let result = wrapping_negate(&a); -/// let expected = PrimitiveArray::from([None, Some(-6), Some(i8::MIN), Some(-7)]); -/// assert_eq!(result, expected); -/// ``` -pub fn wrapping_negate(array: &PrimitiveArray) -> PrimitiveArray -where - T: NativeType + WrappingNeg, -{ - unary(array, |a| a.wrapping_neg(), array.data_type().clone()) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs b/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs deleted file mode 100644 index a1ed463f0195..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs +++ /dev/null @@ -1,25 +0,0 @@ -//! Definition of basic mul operations with primitive arrays -use std::ops::Mul; - -use super::NativeArithmetics; -use crate::array::PrimitiveArray; -use crate::compute::arity::{binary, unary}; - -/// Multiplies two primitive arrays with the same type. -/// Panics if the multiplication of one pair of values overflows. -pub fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray -where - T: NativeArithmetics + Mul, -{ - binary(lhs, rhs, lhs.data_type().clone(), |a, b| a * b) -} - -/// Multiply a scalar T to a primitive array of type T. -/// Panics if the multiplication of the values overflows. -pub fn mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray -where - T: NativeArithmetics + Mul, -{ - let rhs = *rhs; - unary(lhs, |a| a * rhs, lhs.data_type().clone()) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs b/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs deleted file mode 100644 index 46eeb16cb8c6..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::ops::Rem; - -use num_traits::NumCast; -use strength_reduce::{ - StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8, -}; - -use super::NativeArithmetics; -use crate::array::{Array, PrimitiveArray}; -use crate::compute::arity::{binary, unary}; -use crate::datatypes::PrimitiveType; - -/// Remainder of two primitive arrays with the same type. -/// Panics if the divisor is zero of one pair of values overflows. -pub fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray -where - T: NativeArithmetics + Rem, -{ - binary(lhs, rhs, lhs.data_type().clone(), |a, b| a % b) -} - -/// Remainder a primitive array of type T by a scalar T. -/// Panics if the divisor is zero. -pub fn rem_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray -where - T: NativeArithmetics + Rem + NumCast, -{ - let rhs = *rhs; - - match T::PRIMITIVE { - PrimitiveType::UInt64 => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.to_u64().unwrap(); - - let reduced_rem = StrengthReducedU64::new(rhs); - - // small hack to avoid a transmute of `PrimitiveArray` to `PrimitiveArray` - let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); - (&r as &dyn Array) - .as_any() - .downcast_ref::>() - .unwrap() - .clone() - }, - PrimitiveType::UInt32 => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.to_u32().unwrap(); - - let reduced_rem = StrengthReducedU32::new(rhs); - - let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); - // small hack to avoid an unsafe transmute of `PrimitiveArray` to `PrimitiveArray` - (&r as &dyn Array) - .as_any() - .downcast_ref::>() - .unwrap() - .clone() - }, - PrimitiveType::UInt16 => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.to_u16().unwrap(); - - let reduced_rem = StrengthReducedU16::new(rhs); - - let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); - // small hack to avoid an unsafe transmute of `PrimitiveArray` to `PrimitiveArray` - (&r as &dyn Array) - .as_any() - .downcast_ref::>() - .unwrap() - .clone() - }, - PrimitiveType::UInt8 => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.to_u8().unwrap(); - - let reduced_rem = StrengthReducedU8::new(rhs); - - let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); - // small hack to avoid an unsafe transmute of `PrimitiveArray` to `PrimitiveArray` - (&r as &dyn Array) - .as_any() - .downcast_ref::>() - .unwrap() - .clone() - }, - _ => unary(lhs, |a| a % rhs, lhs.data_type().clone()), - } -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs b/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs deleted file mode 100644 index 33acb99b3ef6..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs +++ /dev/null @@ -1,25 +0,0 @@ -//! Definition of basic sub operations with primitive arrays -use std::ops::Sub; - -use super::NativeArithmetics; -use crate::array::PrimitiveArray; -use crate::compute::arity::{binary, unary}; - -/// Subtracts two primitive arrays with the same type. -/// Panics if the subtraction of one pair of values overflows. -pub fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray -where - T: NativeArithmetics + Sub, -{ - binary(lhs, rhs, lhs.data_type().clone(), |a, b| a - b) -} - -/// Subtract a scalar T to a primitive array of type T. -/// Panics if the subtraction of the values overflows. -pub fn sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray -where - T: NativeArithmetics + Sub, -{ - let rhs = *rhs; - unary(lhs, |a| a - rhs, lhs.data_type().clone()) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/mod.rs b/crates/polars-arrow/src/compute/arithmetics/mod.rs deleted file mode 100644 index 38883ee044cf..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod basic; diff --git a/crates/polars-arrow/src/compute/cast/binary_to.rs b/crates/polars-arrow/src/compute/cast/binary_to.rs index e75be2d54d49..c7970fe6a051 100644 --- a/crates/polars-arrow/src/compute/cast/binary_to.rs +++ b/crates/polars-arrow/src/compute/cast/binary_to.rs @@ -156,7 +156,7 @@ fn fixed_size_to_offsets(values_len: usize, fixed_size: usize) -> Off .step_by(fixed_size) .map(|v| O::from_as_usize(v)) .collect(); - // Safety + // SAFETY: // * every element is `>= 0` // * element at position `i` is >= than element at position `i-1`. unsafe { Offsets::new_unchecked(offsets) } diff --git a/crates/polars-arrow/src/compute/cast/boolean_to.rs b/crates/polars-arrow/src/compute/cast/boolean_to.rs index ef07278d5171..c53e59629a8f 100644 --- a/crates/polars-arrow/src/compute/cast/boolean_to.rs +++ b/crates/polars-arrow/src/compute/cast/boolean_to.rs @@ -1,7 +1,7 @@ use polars_error::PolarsResult; -use crate::array::{Array, BinaryArray, BooleanArray, PrimitiveArray, Utf8Array}; -use crate::offset::Offset; +use super::{ArrayFromIter, BinaryViewArray, Utf8ViewArray}; +use crate::array::{Array, BooleanArray, PrimitiveArray}; use crate::types::NativeType; pub(super) fn boolean_to_primitive_dyn(array: &dyn Array) -> PolarsResult> @@ -26,24 +26,26 @@ where PrimitiveArray::::new(T::PRIMITIVE.into(), values.into(), from.validity().cloned()) } -/// Casts the [`BooleanArray`] to a [`Utf8Array`], casting trues to `"1"` and falses to `"0"` -pub fn boolean_to_utf8(from: &BooleanArray) -> Utf8Array { - let iter = from.values().iter().map(|x| if x { "1" } else { "0" }); - Utf8Array::from_trusted_len_values_iter(iter) +pub fn boolean_to_utf8view(from: &BooleanArray) -> Utf8ViewArray { + unsafe { boolean_to_binaryview(from).to_utf8view_unchecked() } } -pub(super) fn boolean_to_utf8_dyn(array: &dyn Array) -> PolarsResult> { +pub(super) fn boolean_to_utf8view_dyn(array: &dyn Array) -> PolarsResult> { let array = array.as_any().downcast_ref().unwrap(); - Ok(Box::new(boolean_to_utf8::(array))) + Ok(boolean_to_utf8view(array).boxed()) } /// Casts the [`BooleanArray`] to a [`BinaryArray`], casting trues to `"1"` and falses to `"0"` -pub fn boolean_to_binary(from: &BooleanArray) -> BinaryArray { - let iter = from.values().iter().map(|x| if x { b"1" } else { b"0" }); - BinaryArray::from_trusted_len_values_iter(iter) +pub fn boolean_to_binaryview(from: &BooleanArray) -> BinaryViewArray { + let iter = from.iter().map(|opt_b| match opt_b { + Some(true) => Some("true".as_bytes()), + Some(false) => Some("false".as_bytes()), + None => None, + }); + BinaryViewArray::arr_from_iter_trusted(iter) } -pub(super) fn boolean_to_binary_dyn(array: &dyn Array) -> PolarsResult> { +pub(super) fn boolean_to_binaryview_dyn(array: &dyn Array) -> PolarsResult> { let array = array.as_any().downcast_ref().unwrap(); - Ok(Box::new(boolean_to_binary::(array))) + Ok(boolean_to_binaryview(array).boxed()) } diff --git a/crates/polars-arrow/src/compute/cast/dictionary_to.rs b/crates/polars-arrow/src/compute/cast/dictionary_to.rs index d29dbe3fdd94..8ef67750dcdc 100644 --- a/crates/polars-arrow/src/compute/cast/dictionary_to.rs +++ b/crates/polars-arrow/src/compute/cast/dictionary_to.rs @@ -1,9 +1,8 @@ use polars_error::{polars_bail, PolarsResult}; use super::{primitive_as_primitive, primitive_to_primitive, CastOptions}; -use crate::array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}; +use crate::array::{Array, DictionaryArray, DictionaryKey}; use crate::compute::cast::cast; -use crate::compute::take::take_unchecked; use crate::datatypes::ArrowDataType; use crate::match_integer_type; @@ -16,7 +15,7 @@ macro_rules! key_cast { if cast_keys.null_count() > $keys.null_count() { polars_bail!(ComputeError: "overflow") } - // Safety: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` + // SAFETY: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` unsafe { DictionaryArray::try_new_unchecked($to_datatype, cast_keys, $values.clone()) } @@ -93,7 +92,7 @@ where Box::new(values.data_type().clone()), is_ordered, ); - // Safety: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` + // SAFETY: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` unsafe { DictionaryArray::try_new_unchecked(data_type, casted_keys, values.clone()) } } } @@ -141,46 +140,12 @@ pub(super) fn dictionary_cast_dyn( // create the appropriate array type let to_key_type = (*to_keys_type).into(); - // Safety: + // SAFETY: // we return an error on overflow so the integers remain within bounds match_integer_type!(to_keys_type, |$T| { key_cast!(keys, values, array, &to_key_type, $T, to_type.clone()) }) }, - _ => unpack_dictionary::(keys, values.as_ref(), to_type, options), + _ => unimplemented!(), } } - -// Unpack the dictionary -fn unpack_dictionary( - keys: &PrimitiveArray, - values: &dyn Array, - to_type: &ArrowDataType, - options: CastOptions, -) -> PolarsResult> -where - K: DictionaryKey + num_traits::NumCast, -{ - // attempt to cast the dict values to the target type - // use the take kernel to expand out the dictionary - let values = cast(values, to_type, options)?; - - // take requires first casting i32 - let indices = primitive_to_primitive::<_, i32>(keys, &ArrowDataType::Int32); - - Ok(unsafe { take_unchecked(values.as_ref(), &indices) }) -} - -/// Casts a [`DictionaryArray`] to its values' [`ArrowDataType`], also known as unpacking. -/// The resulting array has the same length. -pub fn dictionary_to_values(from: &DictionaryArray) -> Box -where - K: DictionaryKey + num_traits::NumCast, -{ - // take requires first casting i64 - let indices = primitive_to_primitive::<_, i64>(from.keys(), &ArrowDataType::Int64); - - // SAFETY: - // The dictionary guarantees that the keys are not out-of-bounds. - unsafe { take_unchecked(from.values().as_ref(), &indices) } -} diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index 2c822b9c0897..015eac0606ea 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -78,6 +78,26 @@ macro_rules! primitive_dyn { }}; } +fn cast_struct( + array: &StructArray, + to_type: &ArrowDataType, + options: CastOptions, +) -> PolarsResult { + let values = array.values(); + let fields = StructArray::get_fields(to_type); + let new_values = values + .iter() + .zip(fields) + .map(|(arr, field)| cast(arr.as_ref(), field.data_type(), options)) + .collect::>>()?; + + Ok(StructArray::new( + to_type.clone(), + new_values, + array.validity().cloned(), + )) +} + fn cast_list( array: &ListArray, to_type: &ArrowDataType, @@ -134,7 +154,7 @@ fn cast_fixed_size_list_to_list( let offsets = (0..=fixed.len()) .map(|ix| O::from_as_usize(ix * fixed.size())) .collect::>(); - // Safety: offsets _are_ monotonically increasing + // SAFETY: offsets _are_ monotonically increasing let offsets = unsafe { Offsets::new_unchecked(offsets) }; Ok(ListArray::::new( @@ -205,7 +225,7 @@ fn cast_list_to_fixed_size_list( } } let take_values = unsafe { - crate::legacy::compute::take::take_unchecked(list.values().as_ref(), &indices.freeze()) + crate::compute::take::take_unchecked(list.values().as_ref(), &indices.freeze()) }; cast(take_values.as_ref(), inner.data_type(), options)? @@ -239,13 +259,14 @@ pub fn cast_unchecked(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResul /// * Fixed Size List to List: the underlying data type is cast /// * List to Fixed Size List: the offsets are checked for valid order, then the /// underlying type is cast. +/// * Struct to Struct: the underlying fields are cast. /// * PrimitiveArray to List: a list array with 1 value per slot is created /// * Date32 and Date64: precision lost when going to higher interval /// * Time32 and Time64: precision lost when going to higher interval /// * Timestamp and Date{32|64}: precision lost when going to higher interval /// * Temporal to/from backing primitive: zero-copy with data type change /// Unsupported Casts -/// * To or from `StructArray` +/// * non-`StructArray` to `StructArray` or `StructArray` to non-`StructArray` /// * List to primitive /// * Utf8 to boolean /// * Interval and duration @@ -265,6 +286,10 @@ pub fn cast( let as_options = options.with_wrapped(true); match (from_type, to_type) { (Null, _) | (_, Null) => Ok(new_null_array(to_type.clone(), array.len())), + (Struct(from_fd), Struct(to_fd)) => { + polars_ensure!(from_fd.len() == to_fd.len(), InvalidOperation: "Cannot cast struct with different number of fields."); + cast_struct(array.as_any().downcast_ref().unwrap(), to_type, options).map(|x| x.boxed()) + }, (Struct(_), _) | (_, Struct(_)) => polars_bail!(InvalidOperation: "Cannot cast from struct to other types" ), @@ -345,7 +370,7 @@ pub fn cast( let values = cast(array, &to.data_type, options)?; // create offsets, where if array.len() = 2, we have [0,1,2] let offsets = (0..=array.len() as i32).collect::>(); - // Safety: offsets _are_ monotonically increasing + // SAFETY: offsets _are_ monotonically increasing let offsets = unsafe { Offsets::new_unchecked(offsets) }; let list_array = ListArray::::new(to_type.clone(), offsets.into(), values, None); @@ -358,7 +383,7 @@ pub fn cast( let values = cast(array, &to.data_type, options)?; // create offsets, where if array.len() = 2, we have [0,1,2] let offsets = (0..=array.len() as i64).collect::>(); - // Safety: offsets _are_ monotonically increasing + // SAFETY: offsets _are_ monotonically increasing let offsets = unsafe { Offsets::new_unchecked(offsets) }; let list_array = ListArray::::new( @@ -437,8 +462,8 @@ pub fn cast( Int64 => boolean_to_primitive_dyn::(array), Float32 => boolean_to_primitive_dyn::(array), Float64 => boolean_to_primitive_dyn::(array), - LargeUtf8 => boolean_to_utf8_dyn::(array), - LargeBinary => boolean_to_binary_dyn::(array), + Utf8View => boolean_to_utf8view_dyn(array), + BinaryView => boolean_to_binaryview_dyn(array), _ => polars_bail!(InvalidOperation: "casting from {from_type:?} to {to_type:?} not supported", ), diff --git a/crates/polars-arrow/src/compute/cast/primitive_to.rs b/crates/polars-arrow/src/compute/cast/primitive_to.rs index 1522729e8f3f..d0d2056b70de 100644 --- a/crates/polars-arrow/src/compute/cast/primitive_to.rs +++ b/crates/polars-arrow/src/compute/cast/primitive_to.rs @@ -85,7 +85,7 @@ fn primitive_to_values_and_offsets( } values.set_len(offset); values.shrink_to_fit(); - // Safety: offsets _are_ monotonically increasing + // SAFETY: offsets _are_ monotonically increasing let offsets = unsafe { Offsets::new_unchecked(offsets) }; (values, offsets) diff --git a/crates/polars-arrow/src/compute/cast/utf8_to.rs b/crates/polars-arrow/src/compute/cast/utf8_to.rs index df827487620b..fadb6552beab 100644 --- a/crates/polars-arrow/src/compute/cast/utf8_to.rs +++ b/crates/polars-arrow/src/compute/cast/utf8_to.rs @@ -38,7 +38,7 @@ pub fn utf8_to_large_utf8(from: &Utf8Array) -> Utf8Array { let values = from.values().clone(); let offsets = from.offsets().into(); - // Safety: sound because `values` fulfills the same invariants as `from.values()` + // SAFETY: sound because `values` fulfills the same invariants as `from.values()` unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) } } @@ -49,7 +49,7 @@ pub fn utf8_large_to_utf8(from: &Utf8Array) -> PolarsResult> let values = from.values().clone(); let offsets = from.offsets().try_into()?; - // Safety: sound because `values` fulfills the same invariants as `from.values()` + // SAFETY: sound because `values` fulfills the same invariants as `from.values()` Ok(unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) }) } @@ -58,7 +58,7 @@ pub fn utf8_to_binary( from: &Utf8Array, to_data_type: ArrowDataType, ) -> BinaryArray { - // Safety: erasure of an invariant is always safe + // SAFETY: erasure of an invariant is always safe unsafe { BinaryArray::::new( to_data_type, diff --git a/crates/polars-arrow/src/compute/mod.rs b/crates/polars-arrow/src/compute/mod.rs index d5c534cfa579..6dba6456d7f6 100644 --- a/crates/polars-arrow/src/compute/mod.rs +++ b/crates/polars-arrow/src/compute/mod.rs @@ -14,9 +14,6 @@ #[cfg(any(feature = "compute_aggregate", feature = "io_parquet"))] #[cfg_attr(docsrs, doc(cfg(feature = "compute_aggregate")))] pub mod aggregate; -#[cfg(feature = "compute_arithmetics")] -#[cfg_attr(docsrs, doc(cfg(feature = "compute_arithmetics")))] -pub mod arithmetics; pub mod arity; pub mod arity_assign; #[cfg(feature = "compute_bitwise")] diff --git a/crates/polars-arrow/src/compute/take/binary.rs b/crates/polars-arrow/src/compute/take/binary.rs index fa0a1ceb4b57..8d2b971ced8f 100644 --- a/crates/polars-arrow/src/compute/take/binary.rs +++ b/crates/polars-arrow/src/compute/take/binary.rs @@ -31,7 +31,7 @@ pub unsafe fn take_unchecked( let (offsets, values, validity) = match (values_has_validity, indices_has_validity) { (false, false) => { - take_no_validity::(values.offsets(), values.values(), indices.values()) + take_no_validity_unchecked::(values.offsets(), values.values(), indices.values()) }, (true, false) => take_values_validity(values, indices.values()), (false, true) => take_indices_validity(values.offsets(), values.values(), indices), diff --git a/crates/polars-arrow/src/compute/take/binview.rs b/crates/polars-arrow/src/compute/take/binview.rs new file mode 100644 index 000000000000..65ff633a080a --- /dev/null +++ b/crates/polars-arrow/src/compute/take/binview.rs @@ -0,0 +1,22 @@ +use self::primitive::take_values_and_validity_unchecked; +use super::*; +use crate::array::BinaryViewArray; + +/// # Safety +/// No bound checks +pub(super) unsafe fn take_binview_unchecked( + arr: &BinaryViewArray, + indices: &IdxArr, +) -> BinaryViewArray { + let (views, validity) = + take_values_and_validity_unchecked(arr.views(), arr.validity(), indices); + + BinaryViewArray::new_unchecked_unknown_md( + arr.data_type().clone(), + views.into(), + arr.data_buffers().clone(), + validity, + Some(arr.total_buffer_len()), + ) + .maybe_gc() +} diff --git a/crates/polars-arrow/src/legacy/compute/take/bitmap.rs b/crates/polars-arrow/src/compute/take/bitmap.rs similarity index 100% rename from crates/polars-arrow/src/legacy/compute/take/bitmap.rs rename to crates/polars-arrow/src/compute/take/bitmap.rs diff --git a/crates/polars-arrow/src/compute/take/boolean.rs b/crates/polars-arrow/src/compute/take/boolean.rs index 45971d995c8a..049a3c4d5d9f 100644 --- a/crates/polars-arrow/src/compute/take/boolean.rs +++ b/crates/polars-arrow/src/compute/take/boolean.rs @@ -1,61 +1,42 @@ -use super::Index; +use super::bitmap::take_bitmap_unchecked; use crate::array::{Array, BooleanArray, PrimitiveArray}; use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::legacy::index::IdxSize; // take implementation when neither values nor indices contain nulls -unsafe fn take_no_validity(values: &Bitmap, indices: &[I]) -> (Bitmap, Option) { - let values = indices - .iter() - .map(|index| values.get_bit_unchecked(index.to_usize())); - let buffer = Bitmap::from_trusted_len_iter(values); - - (buffer, None) +unsafe fn take_no_validity(values: &Bitmap, indices: &[IdxSize]) -> (Bitmap, Option) { + (take_bitmap_unchecked(values, indices), None) } // take implementation when only values contain nulls -unsafe fn take_values_validity( +unsafe fn take_values_validity( values: &BooleanArray, - indices: &[I], + indices: &[IdxSize], ) -> (Bitmap, Option) { let validity_values = values.validity().unwrap(); - let validity = indices - .iter() - .map(|index| validity_values.get_bit_unchecked(index.to_usize())); - let validity = Bitmap::from_trusted_len_iter(validity); + let validity = take_bitmap_unchecked(validity_values, indices); let values_values = values.values(); - let values = indices - .iter() - .map(|index| values_values.get_bit_unchecked(index.to_usize())); - let buffer = Bitmap::from_trusted_len_iter(values); + let buffer = take_bitmap_unchecked(values_values, indices); (buffer, validity.into()) } // take implementation when only indices contain nulls -pub(super) unsafe fn take_indices_validity( +unsafe fn take_indices_validity( values: &Bitmap, - indices: &PrimitiveArray, + indices: &PrimitiveArray, ) -> (Bitmap, Option) { - let validity = indices.validity().unwrap(); - - let values = indices.values().iter().enumerate().map(|(i, index)| { - let index = index.to_usize(); - match values.get(index) { - Some(value) => value, - None => validity.get_bit_unchecked(i), - } - }); - - let buffer = Bitmap::from_trusted_len_iter(values); + // simply take all and copy the bitmap + let buffer = take_bitmap_unchecked(values, indices.values()); (buffer, indices.validity().cloned()) } // take implementation when both values and indices contain nulls -unsafe fn take_values_indices_validity( +unsafe fn take_values_indices_validity( values: &BooleanArray, - indices: &PrimitiveArray, + indices: &PrimitiveArray, ) -> (Bitmap, Option) { let mut validity = MutableBitmap::with_capacity(indices.len()); @@ -63,8 +44,9 @@ unsafe fn take_values_indices_validity( let values_values = values.values(); let values = indices.iter().map(|index| match index { - Some(index) => { - let index = index.to_usize(); + Some(&index) => { + let index = index as usize; + debug_assert!(index < values.len()); validity.push(values_validity.get_bit_unchecked(index)); values_values.get_bit_unchecked(index) }, @@ -78,9 +60,9 @@ unsafe fn take_values_indices_validity( } /// `take` implementation for boolean arrays -pub(super) unsafe fn take_unchecked( +pub unsafe fn take_unchecked( values: &BooleanArray, - indices: &PrimitiveArray, + indices: &PrimitiveArray, ) -> BooleanArray { let data_type = values.data_type().clone(); let indices_has_validity = indices.null_count() > 0; diff --git a/crates/polars-arrow/src/compute/take/dict.rs b/crates/polars-arrow/src/compute/take/dict.rs deleted file mode 100644 index e000aff57344..000000000000 --- a/crates/polars-arrow/src/compute/take/dict.rs +++ /dev/null @@ -1,44 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::primitive::take_unchecked as take_primitive; -use super::Index; -use crate::array::{DictionaryArray, DictionaryKey, PrimitiveArray}; - -/// `take` implementation for dictionary arrays -/// -/// applies `take` to the keys of the dictionary array and returns a new dictionary array -/// with the same dictionary values and reordered keys -pub(super) unsafe fn take_unchecked( - values: &DictionaryArray, - indices: &PrimitiveArray, -) -> DictionaryArray -where - K: DictionaryKey, - I: Index, -{ - let keys = take_primitive::(values.keys(), indices); - // safety - this operation takes a subset of keys and thus preserves the dictionary's invariant - unsafe { - DictionaryArray::::try_new_unchecked( - values.data_type().clone(), - keys, - values.values().clone(), - ) - .unwrap() - } -} diff --git a/crates/polars-arrow/src/compute/take/generic_binary.rs b/crates/polars-arrow/src/compute/take/generic_binary.rs index ac8d9d004ded..74a52134beed 100644 --- a/crates/polars-arrow/src/compute/take/generic_binary.rs +++ b/crates/polars-arrow/src/compute/take/generic_binary.rs @@ -1,4 +1,6 @@ use polars_utils::slice::GetSaferUnchecked; +use polars_utils::unwrap::UnwrapUncheckedRelease; +use polars_utils::vec::{CapacityByFactor, PushUnchecked}; use super::Index; use crate::array::{GenericBinaryArray, PrimitiveArray}; @@ -6,6 +8,23 @@ use crate::bitmap::{Bitmap, MutableBitmap}; use crate::buffer::Buffer; use crate::offset::{Offset, Offsets, OffsetsBuffer}; +fn create_offsets, O: Offset>( + lengths: I, + idx_len: usize, +) -> OffsetsBuffer { + let mut length_so_far = O::default(); + let mut offsets = Vec::with_capacity(idx_len + 1); + offsets.push(length_so_far); + + for len in lengths { + unsafe { + length_so_far += O::from_usize(len).unwrap_unchecked_release(); + offsets.push_unchecked(length_so_far) + }; + } + unsafe { Offsets::new_unchecked(offsets).into() } +} + pub(super) unsafe fn take_values( length: O, starts: &[O], @@ -26,21 +45,23 @@ pub(super) unsafe fn take_values( } // take implementation when neither values nor indices contain nulls -pub fn take_no_validity( +pub(super) unsafe fn take_no_validity_unchecked( offsets: &OffsetsBuffer, values: &[u8], indices: &[I], ) -> (OffsetsBuffer, Buffer, Option) { - let mut buffer = Vec::::new(); + let values_len = offsets.last().to_usize(); + let fraction_estimate = indices.len() as f64 / offsets.len() as f64 + 0.3; + let mut buffer = Vec::::with_capacity_by_factor(values_len, fraction_estimate); + let lengths = indices.iter().map(|index| index.to_usize()).map(|index| { - let (start, end) = offsets.start_end(index); - // todo: remove this bound check - buffer.extend_from_slice(&values[start..end]); + let (start, end) = offsets.start_end_unchecked(index); + buffer.extend_from_slice(values.get_unchecked(start..end)); end - start }); - let offsets = Offsets::try_from_lengths(lengths).expect(""); + let offsets = create_offsets(lengths, indices.len()); - (offsets.into(), buffer.into(), None) + (offsets, buffer.into(), None) } // take implementation when only values contain nulls @@ -51,7 +72,7 @@ pub(super) unsafe fn take_values_validity::with_capacity(indices.len()); - let offsets = indices.iter().map(|index| { + let lengths = indices.iter().map(|index| { let index = index.to_usize(); let start = *offsets.get_unchecked(index); length += *offsets.get_unchecked(index + 1) - start; - starts.push(start); - length + starts.push_unchecked(start); + length.to_usize() }); - let offsets = std::iter::once(O::default()) - .chain(offsets) - .collect::>(); - // Safety: by construction offsets are monotonically increasing - let offsets = unsafe { Offsets::new_unchecked(offsets) }.into(); - + let offsets = create_offsets(lengths, indices.len()); let buffer = take_values(length, starts.as_slice(), &offsets, values_values); (offsets, buffer, validity.into()) @@ -89,23 +105,19 @@ pub(super) unsafe fn take_indices_validity( let offsets = offsets.buffer(); let mut starts = Vec::::with_capacity(indices.len()); - let offsets = indices.values().iter().map(|index| { + let lengths = indices.values().iter().map(|index| { let index = index.to_usize(); match offsets.get(index + 1) { Some(&next) => { let start = *offsets.get_unchecked(index); length += next - start; - starts.push(start); + starts.push_unchecked(start); }, - None => starts.push(O::default()), + None => starts.push_unchecked(O::default()), }; - length + length.to_usize() }); - let offsets = std::iter::once(O::default()) - .chain(offsets) - .collect::>(); - // Safety: by construction offsets are monotonically increasing - let offsets = unsafe { Offsets::new_unchecked(offsets) }.into(); + let offsets = create_offsets(lengths, indices.len()); let buffer = take_values(length, &starts, &offsets, values); @@ -125,7 +137,7 @@ pub(super) unsafe fn take_values_indices_validity::with_capacity(indices.len()); - let offsets = indices.iter().map(|index| { + let lengths = indices.iter().map(|index| { match index { Some(index) => { let index = index.to_usize(); @@ -133,24 +145,20 @@ pub(super) unsafe fn take_values_indices_validity { validity.push(false); - starts.push(O::default()); + starts.push_unchecked(O::default()); }, }; - length + length.to_usize() }); - let offsets = std::iter::once(O::default()) - .chain(offsets) - .collect::>(); - // Safety: by construction offsets are monotonically increasing - let offsets = unsafe { Offsets::new_unchecked(offsets) }.into(); + let offsets = create_offsets(lengths, indices.len()); let buffer = take_values(length, &starts, &offsets, values_values); diff --git a/crates/polars-arrow/src/compute/take/list.rs b/crates/polars-arrow/src/compute/take/list.rs index b3057c12c0c7..e43a91421afa 100644 --- a/crates/polars-arrow/src/compute/take/list.rs +++ b/crates/polars-arrow/src/compute/take/list.rs @@ -17,13 +17,14 @@ use super::Index; use crate::array::growable::{Growable, GrowableList}; -use crate::array::{ListArray, PrimitiveArray}; +use crate::array::ListArray; +use crate::datatypes::IdxArr; use crate::offset::Offset; /// `take` implementation for ListArrays -pub(super) unsafe fn take_unchecked( +pub(super) unsafe fn take_unchecked( values: &ListArray, - indices: &PrimitiveArray, + indices: &IdxArr, ) -> ListArray { let mut capacity = 0; let arrays = indices diff --git a/crates/polars-arrow/src/compute/take/mod.rs b/crates/polars-arrow/src/compute/take/mod.rs index 1abf854a4a65..34b62802dc12 100644 --- a/crates/polars-arrow/src/compute/take/mod.rs +++ b/crates/polars-arrow/src/compute/take/mod.rs @@ -17,29 +17,28 @@ //! Defines take kernel for [`Array`] -use crate::array::{new_empty_array, Array, NullArray, PrimitiveArray}; +use crate::array::{new_empty_array, Array, NullArray, Utf8ViewArray}; +use crate::compute::take::binview::take_binview_unchecked; +use crate::datatypes::IdxArr; use crate::types::Index; mod binary; +mod binview; +mod bitmap; mod boolean; -mod dict; mod fixed_size_list; mod generic_binary; mod list; mod primitive; mod structure; -mod utf8; -use crate::{match_integer_type, with_match_primitive_type}; +use crate::with_match_primitive_type_full; /// Returns a new [`Array`] with only indices at `indices`. Null indices are taken as nulls. /// The returned array has a length equal to `indices.len()`. /// # Safety /// Doesn't do bound checks -pub unsafe fn take_unchecked( - values: &dyn Array, - indices: &PrimitiveArray, -) -> Box { +pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box { if indices.len() == 0 { return new_empty_array(values.data_type().clone()); } @@ -49,37 +48,36 @@ pub unsafe fn take_unchecked( Null => Box::new(NullArray::new(values.data_type().clone(), indices.len())), Boolean => { let values = values.as_any().downcast_ref().unwrap(); - Box::new(boolean::take_unchecked::(values, indices)) + Box::new(boolean::take_unchecked(values, indices)) }, - Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { let values = values.as_any().downcast_ref().unwrap(); - Box::new(primitive::take_unchecked::<$T, _>(&values, indices)) + Box::new(primitive::take_primitive_unchecked::<$T>(&values, indices)) }), - LargeUtf8 => { - let values = values.as_any().downcast_ref().unwrap(); - Box::new(utf8::take_unchecked::(values, indices)) - }, LargeBinary => { let values = values.as_any().downcast_ref().unwrap(); Box::new(binary::take_unchecked::(values, indices)) }, - Dictionary(key_type) => { - match_integer_type!(key_type, |$T| { - let values = values.as_any().downcast_ref().unwrap(); - Box::new(dict::take_unchecked::<$T, _>(&values, indices)) - }) - }, Struct => { let array = values.as_any().downcast_ref().unwrap(); - structure::take_unchecked::<_>(array, indices).boxed() + structure::take_unchecked(array, indices).boxed() }, LargeList => { let array = values.as_any().downcast_ref().unwrap(); - Box::new(list::take_unchecked::(array, indices)) + Box::new(list::take_unchecked::(array, indices)) }, FixedSizeList => { let array = values.as_any().downcast_ref().unwrap(); - Box::new(fixed_size_list::take_unchecked::(array, indices)) + Box::new(fixed_size_list::take_unchecked(array, indices)) + }, + BinaryView => { + take_binview_unchecked(values.as_any().downcast_ref().unwrap(), indices).boxed() + }, + Utf8View => { + let arr: &Utf8ViewArray = values.as_any().downcast_ref().unwrap(); + take_binview_unchecked(&arr.to_binview(), indices) + .to_utf8view_unchecked() + .boxed() }, t => unimplemented!("Take not supported for data type {:?}", t), } diff --git a/crates/polars-arrow/src/compute/take/primitive.rs b/crates/polars-arrow/src/compute/take/primitive.rs index c8db6fac319c..039b64bac680 100644 --- a/crates/polars-arrow/src/compute/take/primitive.rs +++ b/crates/polars-arrow/src/compute/take/primitive.rs @@ -1,101 +1,80 @@ -use super::Index; -use crate::array::{Array, PrimitiveArray}; +use polars_utils::index::NullCount; +use polars_utils::slice::GetSaferUnchecked; + +use crate::array::PrimitiveArray; use crate::bitmap::{Bitmap, MutableBitmap}; -use crate::buffer::Buffer; +use crate::legacy::bit_util::unset_bit_raw; +use crate::legacy::index::IdxArr; +use crate::legacy::utils::CustomIterTools; use crate::types::NativeType; -// take implementation when neither values nor indices contain nulls -unsafe fn take_no_validity( +pub(super) unsafe fn take_values_and_validity_unchecked( values: &[T], - indices: &[I], -) -> (Buffer, Option) { - let values = indices - .iter() - .map(|index| *values.get_unchecked(index.to_usize())) - .collect::>(); - - (values.into(), None) -} - -// take implementation when only values contain nulls -unsafe fn take_values_validity( - values: &PrimitiveArray, - indices: &[I], -) -> (Buffer, Option) { - let values_validity = values.validity().unwrap(); + validity_values: Option<&Bitmap>, + indices: &IdxArr, +) -> (Vec, Option) { + let index_values = indices.values().as_slice(); - let validity = indices - .iter() - .map(|index| values_validity.get_bit_unchecked(index.to_usize())); - let validity = MutableBitmap::from_trusted_len_iter(validity); - - let values_values = values.values(); - - let values = indices - .iter() - .map(|index| *values_values.get_unchecked(index.to_usize())) - .collect::>(); - - (values.into(), validity.into()) -} + let null_count = validity_values.map(|b| b.unset_bits()).unwrap_or(0); -// take implementation when only indices contain nulls -unsafe fn take_indices_validity( - values: &[T], - indices: &PrimitiveArray, -) -> (Buffer, Option) { - let values = indices - .values() - .iter() - .map(|index| { - let index = index.to_usize(); - *values.get_unchecked(index) - }) - .collect::>(); - - (values.into(), indices.validity().cloned()) -} - -// take implementation when both values and indices contain nulls -unsafe fn take_values_indices_validity( - values: &PrimitiveArray, - indices: &PrimitiveArray, -) -> (Buffer, Option) { - let mut bitmap = MutableBitmap::with_capacity(indices.len()); + // first take the values, these are always needed + let values: Vec = if indices.null_count() == 0 { + index_values + .iter() + .map(|idx| *values.get_unchecked_release(*idx as usize)) + .collect_trusted() + } else { + indices + .iter() + .map(|idx| match idx { + Some(idx) => *values.get_unchecked_release(*idx as usize), + None => T::default(), + }) + .collect_trusted() + }; - let values_validity = values.validity().unwrap(); + if null_count > 0 { + let validity_values = validity_values.unwrap(); + // the validity buffer we will fill with all valid. And we unset the ones that are null + // in later checks + // this is in the assumption that most values will be valid. + // Maybe we could add another branch based on the null count + let mut validity = MutableBitmap::with_capacity(indices.len()); + validity.extend_constant(indices.len(), true); + let validity_ptr = validity.as_slice().as_ptr() as *mut u8; - let values_values = values.values(); - let values = indices - .iter() - .map(|index| match index { - Some(index) => { - let index = index.to_usize(); - bitmap.push(values_validity.get_bit_unchecked(index)); - *values_values.get_unchecked(index) - }, - None => { - bitmap.push(false); - T::default() - }, - }) - .collect::>(); - (values.into(), bitmap.into()) + if let Some(validity_indices) = indices.validity().as_ref() { + index_values.iter().enumerate().for_each(|(i, idx)| { + // i is iteration count + // idx is the index that we take from the values array. + let idx = *idx as usize; + if !validity_indices.get_bit_unchecked(i) || !validity_values.get_bit_unchecked(idx) + { + unset_bit_raw(validity_ptr, i); + } + }); + } else { + index_values.iter().enumerate().for_each(|(i, idx)| { + let idx = *idx as usize; + if !validity_values.get_bit_unchecked(idx) { + unset_bit_raw(validity_ptr, i); + } + }); + }; + (values, Some(validity.freeze())) + } else { + (values, indices.validity().cloned()) + } } -/// `take` implementation for primitive arrays -pub(super) unsafe fn take_unchecked( - values: &PrimitiveArray, - indices: &PrimitiveArray, +/// Take kernel for single chunk with nulls and arrow array as index that may have nulls. +/// # Safety +/// caller must ensure indices are in bounds +pub unsafe fn take_primitive_unchecked( + arr: &PrimitiveArray, + indices: &IdxArr, ) -> PrimitiveArray { - let indices_has_validity = indices.null_count() > 0; - let values_has_validity = values.null_count() > 0; - let (buffer, validity) = match (values_has_validity, indices_has_validity) { - (false, false) => take_no_validity::(values.values(), indices.values()), - (true, false) => take_values_validity::(values, indices.values()), - (false, true) => take_indices_validity::(values.values(), indices), - (true, true) => take_values_indices_validity::(values, indices), - }; - - PrimitiveArray::::new(values.data_type().clone(), buffer, validity) + let (values, validity) = + take_values_and_validity_unchecked(arr.values(), arr.validity(), indices); + PrimitiveArray::new_unchecked(arr.data_type().clone(), values.into(), validity) } diff --git a/crates/polars-arrow/src/compute/take/structure.rs b/crates/polars-arrow/src/compute/take/structure.rs index b186e5e0e352..bd9be54dc4b0 100644 --- a/crates/polars-arrow/src/compute/take/structure.rs +++ b/crates/polars-arrow/src/compute/take/structure.rs @@ -15,47 +15,20 @@ // specific language governing permissions and limitations // under the License. -use super::Index; -use crate::array::{Array, PrimitiveArray, StructArray}; -use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::array::{Array, StructArray}; +use crate::compute::utils::combine_validities_and; +use crate::datatypes::IdxArr; -#[inline] -unsafe fn take_validity( - validity: Option<&Bitmap>, - indices: &PrimitiveArray, -) -> Option { - let indices_validity = indices.validity(); - match (validity, indices_validity) { - (None, _) => indices_validity.cloned(), - (Some(validity), None) => { - let iter = indices.values().iter().map(|index| { - let index = index.to_usize(); - validity.get_bit_unchecked(index) - }); - MutableBitmap::from_trusted_len_iter(iter).into() - }, - (Some(validity), _) => { - let iter = indices.iter().map(|x| match x { - Some(index) => { - let index = index.to_usize(); - validity.get_bit_unchecked(index) - }, - None => false, - }); - MutableBitmap::from_trusted_len_iter(iter).into() - }, - } -} - -pub(super) unsafe fn take_unchecked( - array: &StructArray, - indices: &PrimitiveArray, -) -> StructArray { +pub(super) unsafe fn take_unchecked(array: &StructArray, indices: &IdxArr) -> StructArray { let values: Vec> = array .values() .iter() .map(|a| super::take_unchecked(a.as_ref(), indices)) .collect(); - let validity = take_validity(array.validity(), indices); + + let validity = array + .validity() + .map(|b| super::bitmap::take_bitmap_unchecked(b, indices.values())); + let validity = combine_validities_and(validity.as_ref(), indices.validity()); StructArray::new(array.data_type().clone(), values, validity) } diff --git a/crates/polars-arrow/src/compute/take/utf8.rs b/crates/polars-arrow/src/compute/take/utf8.rs deleted file mode 100644 index 69806e0472ae..000000000000 --- a/crates/polars-arrow/src/compute/take/utf8.rs +++ /dev/null @@ -1,41 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::generic_binary::*; -use super::Index; -use crate::array::{Array, PrimitiveArray, Utf8Array}; -use crate::offset::Offset; - -/// `take` implementation for utf8 arrays -pub unsafe fn take_unchecked( - values: &Utf8Array, - indices: &PrimitiveArray, -) -> Utf8Array { - let data_type = values.data_type().clone(); - let indices_has_validity = indices.null_count() > 0; - let values_has_validity = values.null_count() > 0; - - let (offsets, values, validity) = match (values_has_validity, indices_has_validity) { - (false, false) => { - take_no_validity::(values.offsets(), values.values(), indices.values()) - }, - (true, false) => take_values_validity(values, indices.values()), - (false, true) => take_indices_validity(values.offsets(), values.values(), indices), - (true, true) => take_values_indices_validity(values, indices), - }; - unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) } -} diff --git a/crates/polars-arrow/src/compute/utils.rs b/crates/polars-arrow/src/compute/utils.rs index c5e5777e84e2..edac9c8032d0 100644 --- a/crates/polars-arrow/src/compute/utils.rs +++ b/crates/polars-arrow/src/compute/utils.rs @@ -3,7 +3,24 @@ use std::ops::{BitAnd, BitOr}; use polars_error::{polars_bail, polars_ensure, PolarsResult}; use crate::array::Array; -use crate::bitmap::Bitmap; +use crate::bitmap::{ternary, Bitmap}; + +pub fn combine_validities_and3( + opt1: Option<&Bitmap>, + opt2: Option<&Bitmap>, + opt3: Option<&Bitmap>, +) -> Option { + match (opt1, opt2, opt3) { + (Some(a), Some(b), Some(c)) => Some(ternary(a, b, c, |x, y, z| x & y & z)), + (Some(a), Some(b), None) => Some(a.bitand(b)), + (Some(a), None, Some(c)) => Some(a.bitand(c)), + (None, Some(b), Some(c)) => Some(b.bitand(c)), + (Some(a), None, None) => Some(a.clone()), + (None, Some(b), None) => Some(b.clone()), + (None, None, Some(c)) => Some(c.clone()), + (None, None, None) => None, + } +} pub fn combine_validities_and(opt_l: Option<&Bitmap>, opt_r: Option<&Bitmap>) -> Option { match (opt_l, opt_r) { diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index 8fb4019f76f4..95e64447293f 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -592,3 +592,8 @@ pub fn get_extension(metadata: &Metadata) -> Extension { None } } + +#[cfg(not(feature = "bigidx"))] +pub type IdxArr = super::array::UInt32Array; +#[cfg(feature = "bigidx")] +pub type IdxArr = super::array::UInt64Array; diff --git a/crates/polars-arrow/src/ffi/array.rs b/crates/polars-arrow/src/ffi/array.rs index 1e6581ee7550..5de9176e3980 100644 --- a/crates/polars-arrow/src/ffi/array.rs +++ b/crates/polars-arrow/src/ffi/array.rs @@ -93,6 +93,7 @@ struct PrivateData { impl ArrowArray { /// creates a new `ArrowArray` from existing data. + /// /// # Safety /// This method releases `buffers`. Consumers of this struct *must* call `release` before /// releasing this struct, or contents in `buffers` leak. @@ -411,7 +412,8 @@ unsafe fn buffer_len( }) } -/// Safety +/// # Safety +/// /// This function is safe iff: /// * `array.children` at `index` is valid /// * `array.children` is not mutably shared for the lifetime of `parent` @@ -437,7 +439,7 @@ unsafe fn create_child( ); } - // Safety - part of the invariant + // SAFETY: part of the invariant let arr_ptr = unsafe { *array.children.add(index) }; // catch what we can @@ -448,12 +450,13 @@ unsafe fn create_child( ) } - // Safety - invariant of this function + // SAFETY: invariant of this function let arr_ptr = unsafe { &*arr_ptr }; Ok(ArrowArrayChild::new(arr_ptr, data_type, parent)) } -/// Safety +/// # Safety +/// /// This function is safe iff: /// * `array.dictionary` is valid /// * `array.dictionary` is not mutably shared for the lifetime of `parent` @@ -472,7 +475,7 @@ unsafe fn create_dictionary( ) } - // safety: part of the invariant + // SAFETY: part of the invariant let array = unsafe { &*array.dictionary }; Ok(Some(ArrowArrayChild::new(array, data_type, parent))) } else { @@ -488,6 +491,7 @@ pub trait ArrowArrayRef: std::fmt::Debug { /// returns the null bit buffer. /// Rust implementation uses a buffer that is not part of the array of buffers. /// The C Data interface's null buffer is part of the array of buffers. + /// /// # Safety /// The caller must guarantee that the buffer `index` corresponds to a bitmap. /// This function assumes that the bitmap created from FFI is valid; this is impossible to prove. diff --git a/crates/polars-arrow/src/ffi/mmap.rs b/crates/polars-arrow/src/ffi/mmap.rs index bdea26558f29..2267196272d6 100644 --- a/crates/polars-arrow/src/ffi/mmap.rs +++ b/crates/polars-arrow/src/ffi/mmap.rs @@ -118,7 +118,7 @@ pub unsafe fn slice_and_owner(slice: &[T], owner: O) -> Primit let ptr = data.as_ptr(); let data = Arc::new(owner); - // safety: the underlying assumption of this function: the array will not be used + // SAFETY: the underlying assumption of this function: the array will not be used // beyond the let array = create_array( data, @@ -131,7 +131,7 @@ pub unsafe fn slice_and_owner(slice: &[T], owner: O) -> Primit ); let array = InternalArrowArray::new(array, T::PRIMITIVE.into()); - // safety: we just created a valid array + // SAFETY: we just created a valid array unsafe { PrimitiveArray::::try_from_ffi(array) }.unwrap() } @@ -181,7 +181,7 @@ pub unsafe fn bitmap_and_owner( let ptr = data.as_ptr(); let data = Arc::new(owner); - // safety: the underlying assumption of this function: the array will not be used + // SAFETY: the underlying assumption of this function: the array will not be used // beyond the let array = create_array( data, @@ -194,6 +194,6 @@ pub unsafe fn bitmap_and_owner( ); let array = InternalArrowArray::new(array, ArrowDataType::Boolean); - // safety: we just created a valid array + // SAFETY: we just created a valid array Ok(unsafe { BooleanArray::try_from_ffi(array) }.unwrap()) } diff --git a/crates/polars-arrow/src/ffi/schema.rs b/crates/polars-arrow/src/ffi/schema.rs index 09e09e0494b3..23cf9c8c4a47 100644 --- a/crates/polars-arrow/src/ffi/schema.rs +++ b/crates/polars-arrow/src/ffi/schema.rs @@ -1,5 +1,4 @@ use std::collections::BTreeMap; -use std::convert::TryInto; use std::ffi::{CStr, CString}; use std::ptr; diff --git a/crates/polars-arrow/src/ffi/stream.rs b/crates/polars-arrow/src/ffi/stream.rs index 7b7171432f73..58a0b0785529 100644 --- a/crates/polars-arrow/src/ffi/stream.rs +++ b/crates/polars-arrow/src/ffi/stream.rs @@ -54,6 +54,7 @@ impl> ArrowArrayStreamReader { /// # Error /// Errors iff the [`ArrowArrayStream`] is out of specification, /// or was already released prior to calling this function. + /// /// # Safety /// This method is intrinsically `unsafe` since it assumes that the `ArrowArrayStream` /// contains a valid Arrow C stream interface. @@ -101,6 +102,7 @@ impl> ArrowArrayStreamReader { /// Errors iff: /// * The C stream interface returns an error /// * The C stream interface returns an invalid array (that we can identify, see Safety below) + /// /// # Safety /// Calling this iterator's `next` assumes that the [`ArrowArrayStream`] produces arrow arrays /// that fulfill the C data interface @@ -115,7 +117,7 @@ impl> ArrowArrayStreamReader { // last paragraph of https://arrow.apache.org/docs/format/CStreamInterface.html#c.ArrowArrayStream.get_next array.release?; - // Safety: assumed from the C stream interface + // SAFETY: assumed from the C stream interface unsafe { import_array_from_c(array, self.field.data_type.clone()) } .map(Some) .transpose() diff --git a/crates/polars-arrow/src/io/avro/read/deserialize.rs b/crates/polars-arrow/src/io/avro/read/deserialize.rs index 3eaa556f5672..a2c8c83cc9b0 100644 --- a/crates/polars-arrow/src/io/avro/read/deserialize.rs +++ b/crates/polars-arrow/src/io/avro/read/deserialize.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use avro_schema::file::Block; use avro_schema::schema::{Enum, Field as AvroField, Record, Schema as AvroSchema}; use polars_error::{polars_bail, polars_err, PolarsResult}; diff --git a/crates/polars-arrow/src/io/ipc/read/array/binview.rs b/crates/polars-arrow/src/io/ipc/read/array/binview.rs index 40905c740e97..337f9373423c 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/binview.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/binview.rs @@ -1,14 +1,12 @@ -use std::collections::VecDeque; use std::io::{Read, Seek}; use std::sync::Arc; -use polars_error::{polars_err, PolarsResult}; +use polars_error::polars_err; use super::super::read_basic::*; use super::*; use crate::array::{ArrayRef, BinaryViewArrayGeneric, View, ViewType}; use crate::buffer::Buffer; -use crate::datatypes::ArrowDataType; #[allow(clippy::too_many_arguments)] pub fn read_binview( diff --git a/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs b/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs index 846a2a8ea8fa..5a43fe21e102 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs @@ -1,5 +1,4 @@ use std::collections::VecDeque; -use std::convert::TryInto; use std::io::{Read, Seek}; use ahash::HashSet; diff --git a/crates/polars-arrow/src/io/ipc/read/array/list.rs b/crates/polars-arrow/src/io/ipc/read/array/list.rs index c36646fe0192..f29cda5834ad 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/list.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/list.rs @@ -1,5 +1,4 @@ use std::collections::VecDeque; -use std::convert::TryInto; use std::io::{Read, Seek}; use polars_error::{polars_err, PolarsResult}; diff --git a/crates/polars-arrow/src/io/ipc/read/array/primitive.rs b/crates/polars-arrow/src/io/ipc/read/array/primitive.rs index 24b2a05ec6a4..04304aadca90 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/primitive.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/primitive.rs @@ -1,5 +1,4 @@ use std::collections::VecDeque; -use std::convert::TryInto; use std::io::{Read, Seek}; use polars_error::{polars_err, PolarsResult}; diff --git a/crates/polars-arrow/src/io/ipc/read/array/utf8.rs b/crates/polars-arrow/src/io/ipc/read/array/utf8.rs index 1408ff41435e..f29f8d8cdb26 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/utf8.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/utf8.rs @@ -1,13 +1,11 @@ -use std::collections::VecDeque; use std::io::{Read, Seek}; -use polars_error::{polars_err, PolarsResult}; +use polars_error::polars_err; use super::super::read_basic::*; use super::*; use crate::array::Utf8Array; use crate::buffer::Buffer; -use crate::datatypes::ArrowDataType; use crate::offset::Offset; #[allow(clippy::too_many_arguments)] diff --git a/crates/polars-arrow/src/io/ipc/read/common.rs b/crates/polars-arrow/src/io/ipc/read/common.rs index 87005dc76cc4..49638ab2317b 100644 --- a/crates/polars-arrow/src/io/ipc/read/common.rs +++ b/crates/polars-arrow/src/io/ipc/read/common.rs @@ -2,7 +2,6 @@ use std::collections::VecDeque; use std::io::{Read, Seek}; use ahash::AHashMap; -use arrow_format; use polars_error::{polars_bail, polars_err, PolarsResult}; use super::deserialize::{read, skip}; diff --git a/crates/polars-arrow/src/io/ipc/read/file.rs b/crates/polars-arrow/src/io/ipc/read/file.rs index 6f1f4ca8f511..91d276326430 100644 --- a/crates/polars-arrow/src/io/ipc/read/file.rs +++ b/crates/polars-arrow/src/io/ipc/read/file.rs @@ -1,8 +1,8 @@ -use std::convert::TryInto; use std::io::{Read, Seek, SeekFrom}; use std::sync::Arc; use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::FooterRef; use polars_error::{polars_bail, polars_err, PolarsResult}; use polars_utils::aliases::{InitHashMaps, PlHashMap}; @@ -62,6 +62,22 @@ fn read_dictionary_message( Ok(()) } +/// Read the row count by summing the length of the of the record batches +pub fn get_row_count(reader: &mut R) -> PolarsResult { + let mut message_scratch: Vec = Default::default(); + let (_, footer_len) = read_footer_len(reader)?; + let footer = read_footer(reader, footer_len)?; + let (_, blocks) = deserialize_footer_blocks(&footer)?; + + blocks + .into_iter() + .map(|block| { + let message = get_message_from_block(reader, block, &mut message_scratch)?; + let record_batch = get_record_batch(message)?; + record_batch.length().map_err(|e| e.into()) + }) + .sum() +} pub(crate) fn get_dictionary_batch<'a>( message: &'a arrow_format::ipc::MessageRef, @@ -152,6 +168,9 @@ fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); if footer[4..] != ARROW_MAGIC_V2 { + if footer[..4] == ARROW_MAGIC_V1 { + polars_bail!(ComputeError: "feather v1 not supported"); + } return Err(polars_err!(oos = OutOfSpecKind::InvalidFooter)); } let footer_len = footer_len @@ -161,7 +180,22 @@ fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> Ok((end, footer_len)) } -pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult { +fn read_footer(reader: &mut R, footer_len: usize) -> PolarsResult> { + // read footer + reader.seek(SeekFrom::End(-10 - footer_len as i64))?; + + let mut serialized_footer = vec![]; + serialized_footer.try_reserve(footer_len)?; + reader + .by_ref() + .take(footer_len as u64) + .read_to_end(&mut serialized_footer)?; + Ok(serialized_footer) +} + +fn deserialize_footer_blocks( + footer_data: &[u8], +) -> PolarsResult<(FooterRef, Vec)> { let footer = arrow_format::ipc::FooterRef::read_as_root(footer_data) .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?; @@ -178,6 +212,11 @@ pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult< }) }) .collect::>>()?; + Ok((footer, blocks)) +} + +pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult { + let (footer, blocks) = deserialize_footer_blocks(footer_data)?; let ipc_schema = footer .schema() @@ -211,29 +250,9 @@ pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult< /// Read the Arrow IPC file's metadata pub fn read_file_metadata(reader: &mut R) -> PolarsResult { - // check if header contain the correct magic bytes - let mut magic_buffer: [u8; 6] = [0; 6]; let start = reader.stream_position()?; - reader.read_exact(&mut magic_buffer)?; - if magic_buffer != ARROW_MAGIC_V2 { - if magic_buffer[..4] == ARROW_MAGIC_V1 { - polars_bail!(ComputeError: "feather v1 not supported"); - } - polars_bail!(oos = OutOfSpecKind::InvalidHeader); - } - let (end, footer_len) = read_footer_len(reader)?; - - // read footer - reader.seek(SeekFrom::End(-10 - footer_len as i64))?; - - let mut serialized_footer = vec![]; - serialized_footer.try_reserve(footer_len)?; - reader - .by_ref() - .take(footer_len as u64) - .read_to_end(&mut serialized_footer)?; - + let serialized_footer = read_footer(reader, footer_len)?; deserialize_footer(&serialized_footer, end - start) } @@ -250,6 +269,47 @@ pub(crate) fn get_record_batch( } } +fn get_message_from_block_offset<'a, R: Read + Seek>( + reader: &mut R, + offset: u64, + message_scratch: &'a mut Vec, +) -> PolarsResult> { + // read length + reader.seek(SeekFrom::Start(offset))?; + let mut meta_buf = [0; 4]; + reader.read_exact(&mut meta_buf)?; + if meta_buf == CONTINUATION_MARKER { + // continuation marker encountered, read message next + reader.read_exact(&mut meta_buf)?; + } + let meta_len = i32::from_le_bytes(meta_buf) + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; + + message_scratch.clear(); + message_scratch.try_reserve(meta_len)?; + reader + .by_ref() + .take(meta_len as u64) + .read_to_end(message_scratch)?; + + arrow_format::ipc::MessageRef::read_as_root(message_scratch) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err))) +} + +fn get_message_from_block<'a, R: Read + Seek>( + reader: &mut R, + block: arrow_format::ipc::Block, + message_scratch: &'a mut Vec, +) -> PolarsResult> { + let offset: u64 = block + .offset + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + get_message_from_block_offset(reader, offset, message_scratch) +} + /// Reads the record batch at position `index` from the reader. /// /// This function is useful for random access to the file. For example, if @@ -280,28 +340,7 @@ pub fn read_batch( .try_into() .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; - // read length - reader.seek(SeekFrom::Start(offset))?; - let mut meta_buf = [0; 4]; - reader.read_exact(&mut meta_buf)?; - if meta_buf == CONTINUATION_MARKER { - // continuation marker encountered, read message next - reader.read_exact(&mut meta_buf)?; - } - let meta_len = i32::from_le_bytes(meta_buf) - .try_into() - .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; - - message_scratch.clear(); - message_scratch.try_reserve(meta_len)?; - reader - .by_ref() - .take(meta_len as u64) - .read_to_end(message_scratch)?; - - let message = arrow_format::ipc::MessageRef::read_as_root(message_scratch.as_ref()) - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; - + let message = get_message_from_block_offset(reader, offset, message_scratch)?; let batch = get_record_batch(message)?; read_record_batch( diff --git a/crates/polars-arrow/src/io/ipc/read/mod.rs b/crates/polars-arrow/src/io/ipc/read/mod.rs index 3688816273e5..f79376934427 100644 --- a/crates/polars-arrow/src/io/ipc/read/mod.rs +++ b/crates/polars-arrow/src/io/ipc/read/mod.rs @@ -17,6 +17,7 @@ mod schema; mod stream; pub use error::OutOfSpecKind; +pub use file::get_row_count; #[cfg(feature = "io_ipc_read_async")] #[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] diff --git a/crates/polars-arrow/src/io/ipc/read/read_basic.rs b/crates/polars-arrow/src/io/ipc/read/read_basic.rs index 3864b24bf26c..7dbea604ccf3 100644 --- a/crates/polars-arrow/src/io/ipc/read/read_basic.rs +++ b/crates/polars-arrow/src/io/ipc/read/read_basic.rs @@ -1,5 +1,4 @@ use std::collections::VecDeque; -use std::convert::TryInto; use std::io::{Read, Seek, SeekFrom}; use polars_error::{polars_bail, polars_err, PolarsResult}; diff --git a/crates/polars-arrow/src/io/ipc/read/stream.rs b/crates/polars-arrow/src/io/ipc/read/stream.rs index 5fab1d826211..72b076364294 100644 --- a/crates/polars-arrow/src/io/ipc/read/stream.rs +++ b/crates/polars-arrow/src/io/ipc/read/stream.rs @@ -1,7 +1,6 @@ use std::io::Read; use ahash::AHashMap; -use arrow_format; use arrow_format::ipc::planus::ReadAsRoot; use polars_error::{polars_bail, polars_err, PolarsError, PolarsResult}; diff --git a/crates/polars-arrow/src/io/ipc/write/common.rs b/crates/polars-arrow/src/io/ipc/write/common.rs index 1d4375280838..7fb4a9e5c0ee 100644 --- a/crates/polars-arrow/src/io/ipc/write/common.rs +++ b/crates/polars-arrow/src/io/ipc/write/common.rs @@ -280,10 +280,32 @@ fn chunk_to_bytes_amortized( let mut offset = 0; let mut variadic_buffer_counts = vec![]; for array in chunk.arrays() { - set_variadic_buffer_counts(&mut variadic_buffer_counts, array.as_ref()); + // We don't want to write all buffers in sliced arrays. + let array = match array.data_type() { + ArrowDataType::BinaryView => { + let concrete_arr = array.as_any().downcast_ref::().unwrap(); + if concrete_arr.is_sliced() { + Cow::Owned(concrete_arr.clone().maybe_gc().boxed()) + } else { + Cow::Borrowed(array) + } + }, + ArrowDataType::Utf8View => { + let concrete_arr = array.as_any().downcast_ref::().unwrap(); + if concrete_arr.is_sliced() { + Cow::Owned(concrete_arr.clone().maybe_gc().boxed()) + } else { + Cow::Borrowed(array) + } + }, + _ => Cow::Borrowed(array), + }; + let array = array.as_ref().as_ref(); + + set_variadic_buffer_counts(&mut variadic_buffer_counts, array); write( - array.as_ref(), + array, &mut buffers, &mut arrow_data, &mut nodes, diff --git a/crates/polars-arrow/src/legacy/array/list.rs b/crates/polars-arrow/src/legacy/array/list.rs index 7dcb6e5de3a2..ff02011663cb 100644 --- a/crates/polars-arrow/src/legacy/array/list.rs +++ b/crates/polars-arrow/src/legacy/array/list.rs @@ -41,7 +41,7 @@ impl<'a> AnonymousBuilder<'a> { } pub fn take_offsets(self) -> Offsets { - // safety: offsets are correct + // SAFETY: offsets are correct unsafe { Offsets::new_unchecked(self.offsets) } } @@ -102,7 +102,7 @@ impl<'a> AnonymousBuilder<'a> { } pub fn finish(self, inner_dtype: Option<&ArrowDataType>) -> PolarsResult> { - // Safety: + // SAFETY: // offsets are monotonically increasing let offsets = unsafe { Offsets::new_unchecked(self.offsets) }; let (inner_dtype, values) = if self.arrays.is_empty() { diff --git a/crates/polars-arrow/src/legacy/array/mod.rs b/crates/polars-arrow/src/legacy/array/mod.rs index 1e6d59bb430d..489cfbb0dbeb 100644 --- a/crates/polars-arrow/src/legacy/array/mod.rs +++ b/crates/polars-arrow/src/legacy/array/mod.rs @@ -67,7 +67,7 @@ pub trait ListFromIter { let values: PrimitiveArray = iter_to_values!(iterator, validity, offsets, length_so_far); - // Safety: + // SAFETY: // offsets are monotonically increasing ListArray::new( ListArray::::default_datatype(data_type.clone()), @@ -97,7 +97,7 @@ pub trait ListFromIter { let values: BooleanArray = iter_to_values!(iterator, validity, offsets, length_so_far); - // Safety: + // SAFETY: // Offsets are monotonically increasing. ListArray::new( ListArray::::default_datatype(ArrowDataType::Boolean), @@ -145,7 +145,7 @@ pub trait ListFromIter { .trust_my_length(n_elements) .collect(); - // Safety: + // SAFETY: // offsets are monotonically increasing ListArray::new( ListArray::::default_datatype(T::DATA_TYPE), diff --git a/crates/polars-arrow/src/legacy/array/slice.rs b/crates/polars-arrow/src/legacy/array/slice.rs index 63997e78b88c..720723c901a8 100644 --- a/crates/polars-arrow/src/legacy/array/slice.rs +++ b/crates/polars-arrow/src/legacy/array/slice.rs @@ -15,6 +15,7 @@ pub trait SlicedArray { /// Slices the [`Array`]. /// # Implementation /// This operation is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()` unsafe fn slice_typed_unchecked(&self, offset: usize, length: usize) -> Self diff --git a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/add.rs b/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/add.rs deleted file mode 100644 index 17089326d36f..000000000000 --- a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/add.rs +++ /dev/null @@ -1,16 +0,0 @@ -use super::*; - -pub fn add( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PolarsResult> { - commutative(lhs, rhs, |a, b| a + b) -} - -pub fn add_scalar( - lhs: &PrimitiveArray, - rhs: i128, - rhs_dtype: &ArrowDataType, -) -> PolarsResult> { - commutative_scalar(lhs, rhs, rhs_dtype, |a, b| a + b) -} diff --git a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/commutative.rs b/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/commutative.rs deleted file mode 100644 index 9b18d45234ae..000000000000 --- a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/commutative.rs +++ /dev/null @@ -1,89 +0,0 @@ -use polars_error::*; - -use super::{get_parameters, max_value}; -use crate::array::PrimitiveArray; -use crate::datatypes::ArrowDataType; -use crate::legacy::compute::{binary_mut, unary_mut}; - -pub fn commutative( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, - op: F, -) -> PolarsResult> -where - F: Fn(i128, i128) -> i128, -{ - let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type())?; - - let max = max_value(precision); - let mut overflow = false; - let op = |a, b| { - let res = op(a, b); - overflow |= res.abs() > max; - res - }; - let out = binary_mut(lhs, rhs, lhs.data_type().clone(), op); - polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); - Ok(out) -} - -pub fn commutative_scalar( - lhs: &PrimitiveArray, - rhs: i128, - rhs_dtype: &ArrowDataType, - op: F, -) -> PolarsResult> -where - F: Fn(i128, i128) -> i128, -{ - let (precision, _) = get_parameters(lhs.data_type(), rhs_dtype)?; - - let max = max_value(precision); - let mut overflow = false; - let op = |a| { - let res = op(a, rhs); - overflow |= res.abs() > max; - res - }; - let out = unary_mut(lhs, op, lhs.data_type().clone()); - polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); - - Ok(out) -} - -pub fn non_commutative( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, - op: F, -) -> PolarsResult> -where - F: Fn(i128, i128) -> i128, -{ - Ok(binary_mut(lhs, rhs, lhs.data_type().clone(), op)) -} - -pub fn non_commutative_scalar( - lhs: &PrimitiveArray, - rhs: i128, - op: F, -) -> PolarsResult> -where - F: Fn(i128, i128) -> i128, -{ - let op = move |a| op(a, rhs); - - Ok(unary_mut(lhs, op, lhs.data_type().clone())) -} - -pub fn non_commutative_scalar_swapped( - lhs: i128, - rhs: &PrimitiveArray, - op: F, -) -> PolarsResult> -where - F: Fn(i128, i128) -> i128, -{ - let op = move |a| op(lhs, a); - - Ok(unary_mut(rhs, op, rhs.data_type().clone())) -} diff --git a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/div.rs b/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/div.rs deleted file mode 100644 index cb600d8f781a..000000000000 --- a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/div.rs +++ /dev/null @@ -1,50 +0,0 @@ -use ethnum::I256; - -use super::*; - -#[inline] -fn decimal_div(a: i128, b: i128, scale: i128) -> i128 { - // The division is done using the numbers without scale. - // The dividend is scaled up to maintain precision after the - // division - - // 222.222 --> 222222000 - // 123.456 --> 123456 - // -------- --------- - // 1.800 <-- 1800 - - // operate in I256 space to reduce overflow - let a = I256::new(a); - let b = I256::new(b); - let scale = I256::new(scale); - (a * scale / b).as_i128() -} - -pub fn div( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PolarsResult> { - let (_, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; - let scale = 10i128.pow(scale as u32); - non_commutative(lhs, rhs, |a, b| decimal_div(a, b, scale)) -} - -pub fn div_scalar( - lhs: &PrimitiveArray, - rhs: i128, - rhs_dtype: &ArrowDataType, -) -> PolarsResult> { - let (_, scale) = get_parameters(lhs.data_type(), rhs_dtype)?; - let scale = 10i128.pow(scale as u32); - non_commutative_scalar(lhs, rhs, |a, b| decimal_div(a, b, scale)) -} - -pub fn div_scalar_swapped( - lhs: i128, - lhs_dtype: &ArrowDataType, - rhs: &PrimitiveArray, -) -> PolarsResult> { - let (_, scale) = get_parameters(lhs_dtype, rhs.data_type())?; - let scale = 10i128.pow(scale as u32); - non_commutative_scalar_swapped(lhs, rhs, |a, b| decimal_div(a, b, scale)) -} diff --git a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/mod.rs b/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/mod.rs deleted file mode 100644 index 52a9765129b6..000000000000 --- a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/mod.rs +++ /dev/null @@ -1,41 +0,0 @@ -use commutative::{ - commutative, commutative_scalar, non_commutative, non_commutative_scalar, - non_commutative_scalar_swapped, -}; -use polars_error::{PolarsError, PolarsResult}; - -use crate::array::PrimitiveArray; -use crate::datatypes::ArrowDataType; - -mod add; -mod commutative; -mod div; -mod mul; -mod sub; - -pub use add::*; -pub use div::*; -pub use mul::*; -pub use sub::*; - -/// Maximum value that can exist with a selected precision -#[inline] -fn max_value(precision: usize) -> i128 { - 10i128.pow(precision as u32) - 1 -} - -fn get_parameters(lhs: &ArrowDataType, rhs: &ArrowDataType) -> PolarsResult<(usize, usize)> { - if let (ArrowDataType::Decimal(lhs_p, lhs_s), ArrowDataType::Decimal(rhs_p, rhs_s)) = - (lhs.to_logical_type(), rhs.to_logical_type()) - { - if lhs_p == rhs_p && lhs_s == rhs_s { - Ok((*lhs_p, *lhs_s)) - } else { - Err(PolarsError::InvalidOperation( - "Arrays must have the same precision and scale".into(), - )) - } - } else { - unreachable!() - } -} diff --git a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/mul.rs b/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/mul.rs deleted file mode 100644 index 7e6640444011..000000000000 --- a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/mul.rs +++ /dev/null @@ -1,41 +0,0 @@ -use ethnum::I256; - -use super::*; - -#[inline] -fn decimal_mul(a: i128, b: i128, scale: i128) -> i128 { - // The multiplication is done using the numbers without scale. - // The resulting scale of the value has to be corrected by - // dividing by (10^scale) - - // 111.111 --> 111111 - // 222.222 --> 222222 - // -------- ------- - // 24691.308 <-- 24691308642 - - // operate in I256 space to reduce overflow - let a = I256::new(a); - let b = I256::new(b); - let scale = I256::new(scale); - - (a * b / scale).as_i128() -} - -pub fn mul( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PolarsResult> { - let (_, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; - let scale = 10i128.pow(scale as u32); - commutative(lhs, rhs, |a, b| decimal_mul(a, b, scale)) -} - -pub fn mul_scalar( - lhs: &PrimitiveArray, - rhs: i128, - rhs_dtype: &ArrowDataType, -) -> PolarsResult> { - let (_, scale) = get_parameters(lhs.data_type(), rhs_dtype)?; - let scale = 10i128.pow(scale as u32); - commutative_scalar(lhs, rhs, rhs_dtype, |a, b| decimal_mul(a, b, scale)) -} diff --git a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/sub.rs b/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/sub.rs deleted file mode 100644 index da67a8593bde..000000000000 --- a/crates/polars-arrow/src/legacy/compute/arithmetics/decimal/sub.rs +++ /dev/null @@ -1,19 +0,0 @@ -use super::*; - -pub fn sub( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PolarsResult> { - non_commutative(lhs, rhs, |a, b| a - b) -} - -pub fn sub_scalar(lhs: &PrimitiveArray, rhs: i128) -> PolarsResult> { - non_commutative_scalar(lhs, rhs, |a, b| a - b) -} - -pub fn sub_scalar_swapped( - lhs: i128, - rhs: &PrimitiveArray, -) -> PolarsResult> { - non_commutative_scalar_swapped(lhs, rhs, |a, b| a - b) -} diff --git a/crates/polars-arrow/src/legacy/compute/arithmetics/mod.rs b/crates/polars-arrow/src/legacy/compute/arithmetics/mod.rs deleted file mode 100644 index 0abcbaba757a..000000000000 --- a/crates/polars-arrow/src/legacy/compute/arithmetics/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[cfg(feature = "dtype-decimal")] -pub mod decimal; diff --git a/crates/polars-arrow/src/legacy/compute/bitwise.rs b/crates/polars-arrow/src/legacy/compute/bitwise.rs deleted file mode 100644 index 487363028f0c..000000000000 --- a/crates/polars-arrow/src/legacy/compute/bitwise.rs +++ /dev/null @@ -1,26 +0,0 @@ -use std::ops::{BitAnd, BitOr, BitXor}; - -use crate::array::PrimitiveArray; -use crate::compute::arity::binary; -use crate::types::NativeType; - -pub fn bitand(a: &PrimitiveArray, b: &PrimitiveArray) -> PrimitiveArray -where - T: BitAnd, -{ - binary(a, b, a.data_type().clone(), |a, b| a.bitand(b)) -} - -pub fn bitor(a: &PrimitiveArray, b: &PrimitiveArray) -> PrimitiveArray -where - T: BitOr, -{ - binary(a, b, a.data_type().clone(), |a, b| a.bitor(b)) -} - -pub fn bitxor(a: &PrimitiveArray, b: &PrimitiveArray) -> PrimitiveArray -where - T: BitXor, -{ - binary(a, b, a.data_type().clone(), |a, b| a.bitxor(b)) -} diff --git a/crates/polars-arrow/src/legacy/compute/mod.rs b/crates/polars-arrow/src/legacy/compute/mod.rs index 9bdba88e7d7a..fe5cfb198ba8 100644 --- a/crates/polars-arrow/src/legacy/compute/mod.rs +++ b/crates/polars-arrow/src/legacy/compute/mod.rs @@ -3,11 +3,9 @@ use crate::compute::utils::combine_validities_and; use crate::datatypes::ArrowDataType; use crate::types::NativeType; -pub mod arithmetics; -pub mod bitwise; #[cfg(feature = "dtype-decimal")] pub mod decimal; -pub mod take; +// pub mod take; pub mod tile; #[inline] diff --git a/crates/polars-arrow/src/legacy/compute/take/boolean.rs b/crates/polars-arrow/src/legacy/compute/take/boolean.rs deleted file mode 100644 index 049a3c4d5d9f..000000000000 --- a/crates/polars-arrow/src/legacy/compute/take/boolean.rs +++ /dev/null @@ -1,79 +0,0 @@ -use super::bitmap::take_bitmap_unchecked; -use crate::array::{Array, BooleanArray, PrimitiveArray}; -use crate::bitmap::{Bitmap, MutableBitmap}; -use crate::legacy::index::IdxSize; - -// take implementation when neither values nor indices contain nulls -unsafe fn take_no_validity(values: &Bitmap, indices: &[IdxSize]) -> (Bitmap, Option) { - (take_bitmap_unchecked(values, indices), None) -} - -// take implementation when only values contain nulls -unsafe fn take_values_validity( - values: &BooleanArray, - indices: &[IdxSize], -) -> (Bitmap, Option) { - let validity_values = values.validity().unwrap(); - let validity = take_bitmap_unchecked(validity_values, indices); - - let values_values = values.values(); - let buffer = take_bitmap_unchecked(values_values, indices); - - (buffer, validity.into()) -} - -// take implementation when only indices contain nulls -unsafe fn take_indices_validity( - values: &Bitmap, - indices: &PrimitiveArray, -) -> (Bitmap, Option) { - // simply take all and copy the bitmap - let buffer = take_bitmap_unchecked(values, indices.values()); - - (buffer, indices.validity().cloned()) -} - -// take implementation when both values and indices contain nulls -unsafe fn take_values_indices_validity( - values: &BooleanArray, - indices: &PrimitiveArray, -) -> (Bitmap, Option) { - let mut validity = MutableBitmap::with_capacity(indices.len()); - - let values_validity = values.validity().unwrap(); - - let values_values = values.values(); - let values = indices.iter().map(|index| match index { - Some(&index) => { - let index = index as usize; - debug_assert!(index < values.len()); - validity.push(values_validity.get_bit_unchecked(index)); - values_values.get_bit_unchecked(index) - }, - None => { - validity.push(false); - false - }, - }); - let values = Bitmap::from_trusted_len_iter(values); - (values, validity.into()) -} - -/// `take` implementation for boolean arrays -pub unsafe fn take_unchecked( - values: &BooleanArray, - indices: &PrimitiveArray, -) -> BooleanArray { - let data_type = values.data_type().clone(); - let indices_has_validity = indices.null_count() > 0; - let values_has_validity = values.null_count() > 0; - - let (values, validity) = match (values_has_validity, indices_has_validity) { - (false, false) => take_no_validity(values.values(), indices.values()), - (true, false) => take_values_validity(values, indices.values()), - (false, true) => take_indices_validity(values.values(), indices), - (true, true) => take_values_indices_validity(values, indices), - }; - - BooleanArray::new(data_type, values, validity) -} diff --git a/crates/polars-arrow/src/legacy/compute/take/fixed_size_list.rs b/crates/polars-arrow/src/legacy/compute/take/fixed_size_list.rs deleted file mode 100644 index 7d6a6ba948ff..000000000000 --- a/crates/polars-arrow/src/legacy/compute/take/fixed_size_list.rs +++ /dev/null @@ -1,109 +0,0 @@ -use crate::array::growable::{Growable, GrowableFixedSizeList}; -use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray}; -use crate::bitmap::{Bitmap, MutableBitmap}; -use crate::datatypes::{ArrowDataType, PhysicalType}; -use crate::legacy::index::{IdxArr, IdxSize}; -use crate::types::NativeType; -use crate::with_match_primitive_type; - -pub unsafe fn take_unchecked(values: &FixedSizeListArray, indices: &IdxArr) -> FixedSizeListArray { - if let (PhysicalType::Primitive(primitive), 0) = ( - values.values().data_type().to_physical_type(), - indices.null_count(), - ) { - let idx = indices.values().as_slice(); - let child_values = values.values(); - let ArrowDataType::FixedSizeList(_, width) = values.data_type() else { - unreachable!() - }; - - with_match_primitive_type!(primitive, |$T| { - let arr: &PrimitiveArray<$T> = child_values.as_any().downcast_ref().unwrap(); - return take_unchecked_primitive(values, arr, idx, *width) - }) - } - - let mut capacity = 0; - let arrays = indices - .values() - .iter() - .map(|index| { - let index = *index as usize; - let slice = values.clone().sliced_unchecked(index, 1); - capacity += slice.len(); - slice - }) - .collect::>(); - - let arrays = arrays.iter().collect(); - - if let Some(validity) = indices.validity() { - let mut growable: GrowableFixedSizeList = - GrowableFixedSizeList::new(arrays, true, capacity); - - for index in 0..indices.len() { - if validity.get_bit(index) { - growable.extend(index, 0, 1); - } else { - growable.extend_validity(1) - } - } - - growable.into() - } else { - let mut growable: GrowableFixedSizeList = - GrowableFixedSizeList::new(arrays, false, capacity); - for index in 0..indices.len() { - growable.extend(index, 0, 1); - } - - growable.into() - } -} - -unsafe fn take_bitmap_unchecked(bitmap: &Bitmap, idx: &[IdxSize], width: usize) -> Bitmap { - let mut out = MutableBitmap::with_capacity(idx.len() * width); - let (slice, offset, _len) = bitmap.as_slice(); - - for &idx in idx { - out.extend_from_slice_unchecked(slice, offset + idx as usize * width, width) - } - out.into() -} - -unsafe fn take_unchecked_primitive( - parent: &FixedSizeListArray, - list_values: &PrimitiveArray, - idx: &[IdxSize], - width: usize, -) -> FixedSizeListArray { - let values = list_values.values().as_slice(); - let mut out = Vec::with_capacity(idx.len() * width); - - for &i in idx { - let start = i as usize * width; - let end = start + width; - out.extend_from_slice(values.get_unchecked(start..end)); - } - - let validity = if list_values.null_count() > 0 { - let validity = list_values.validity().unwrap(); - Some(take_bitmap_unchecked(validity, idx, width)) - } else { - None - }; - let list_values = Box::new(PrimitiveArray::new( - list_values.data_type().clone(), - out.into(), - validity, - )) as ArrayRef; - let validity = if parent.null_count() > 0 { - Some(super::bitmap::take_bitmap_unchecked( - parent.validity().unwrap(), - idx, - )) - } else { - None - }; - FixedSizeListArray::new(parent.data_type().clone(), list_values, validity) -} diff --git a/crates/polars-arrow/src/legacy/compute/take/mod.rs b/crates/polars-arrow/src/legacy/compute/take/mod.rs deleted file mode 100644 index dc1ebc9cf8cd..000000000000 --- a/crates/polars-arrow/src/legacy/compute/take/mod.rs +++ /dev/null @@ -1,248 +0,0 @@ -pub mod bitmap; -mod boolean; -#[cfg(feature = "dtype-array")] -mod fixed_size_list; - -use polars_utils::slice::GetSaferUnchecked; - -use crate::array::*; -use crate::bitmap::{Bitmap, MutableBitmap}; -use crate::datatypes::PhysicalType; -use crate::legacy::bit_util::unset_bit_raw; -use crate::legacy::prelude::*; -use crate::legacy::utils::CustomIterTools; -use crate::offset::Offsets; -use crate::types::NativeType; -use crate::with_match_primitive_type; - -/// # Safety -/// Does not do bounds checks -pub unsafe fn take_unchecked(arr: &dyn Array, idx: &IdxArr) -> ArrayRef { - if idx.null_count() == idx.len() { - return new_null_array(arr.data_type().clone(), idx.len()); - } - use PhysicalType::*; - match arr.data_type().to_physical_type() { - Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { - let arr: &PrimitiveArray<$T> = arr.as_any().downcast_ref().unwrap(); - take_primitive_unchecked::<$T>(arr, idx).boxed() - }), - Boolean => { - let arr = arr.as_any().downcast_ref().unwrap(); - Box::new(boolean::take_unchecked(arr, idx)) - }, - #[cfg(feature = "dtype-array")] - FixedSizeList => { - let arr = arr.as_any().downcast_ref().unwrap(); - Box::new(fixed_size_list::take_unchecked(arr, idx)) - }, - BinaryView => take_binview_unchecked(arr.as_any().downcast_ref().unwrap(), idx).boxed(), - Utf8View => { - let arr: &Utf8ViewArray = arr.as_any().downcast_ref().unwrap(); - take_binview_unchecked(&arr.to_binview(), idx) - .to_utf8view_unchecked() - .boxed() - }, - Struct => { - let array = arr.as_any().downcast_ref().unwrap(); - take_struct_unchecked(array, idx).boxed() - }, - // TODO! implement proper unchecked version - #[cfg(feature = "compute")] - _ => { - use crate::compute::take::take_unchecked; - take_unchecked(arr, idx) - }, - #[cfg(not(feature = "compute"))] - _ => { - panic!("activate compute feature") - }, - } -} - -unsafe fn take_validity_unchecked(validity: Option<&Bitmap>, indices: &IdxArr) -> Option { - let indices_validity = indices.validity(); - match (validity, indices_validity) { - (None, _) => indices_validity.cloned(), - (Some(validity), None) => { - let iter = indices - .values() - .iter() - .map(|index| validity.get_bit_unchecked(*index as usize)); - MutableBitmap::from_trusted_len_iter(iter).into() - }, - (Some(validity), _) => { - let iter = indices.iter().map(|x| match x { - Some(index) => validity.get_bit_unchecked(*index as usize), - None => false, - }); - MutableBitmap::from_trusted_len_iter(iter).into() - }, - } -} - -/// # Safety -/// No bound checks -pub unsafe fn take_struct_unchecked(array: &StructArray, indices: &IdxArr) -> StructArray { - let values: Vec> = array - .values() - .iter() - .map(|a| take_unchecked(a.as_ref(), indices)) - .collect(); - let validity = take_validity_unchecked(array.validity(), indices); - StructArray::new(array.data_type().clone(), values, validity) -} - -/// # Safety -/// No bound checks -unsafe fn take_binview_unchecked(arr: &BinaryViewArray, indices: &IdxArr) -> BinaryViewArray { - let (views, validity) = - take_values_and_validity_unchecked(arr.views(), arr.validity(), indices); - - BinaryViewArray::new_unchecked_unknown_md( - arr.data_type().clone(), - views.into(), - arr.data_buffers().clone(), - validity, - Some(arr.total_buffer_len()), - ) - .maybe_gc() -} - -unsafe fn take_values_and_validity_unchecked( - values: &[T], - validity_values: Option<&Bitmap>, - indices: &IdxArr, -) -> (Vec, Option) { - let index_values = indices.values().as_slice(); - - let null_count = validity_values.map(|b| b.unset_bits()).unwrap_or(0); - - // first take the values, these are always needed - let values: Vec = index_values - .iter() - .map(|idx| *values.get_unchecked_release(*idx as usize)) - .collect_trusted(); - - if null_count > 0 { - let validity_values = validity_values.unwrap(); - // the validity buffer we will fill with all valid. And we unset the ones that are null - // in later checks - // this is in the assumption that most values will be valid. - // Maybe we could add another branch based on the null count - let mut validity = MutableBitmap::with_capacity(indices.len()); - validity.extend_constant(indices.len(), true); - let validity_ptr = validity.as_slice().as_ptr() as *mut u8; - - if let Some(validity_indices) = indices.validity().as_ref() { - index_values.iter().enumerate().for_each(|(i, idx)| { - // i is iteration count - // idx is the index that we take from the values array. - let idx = *idx as usize; - if !validity_indices.get_bit_unchecked(i) || !validity_values.get_bit_unchecked(idx) - { - unset_bit_raw(validity_ptr, i); - } - }); - } else { - index_values.iter().enumerate().for_each(|(i, idx)| { - let idx = *idx as usize; - if !validity_values.get_bit_unchecked(idx) { - unset_bit_raw(validity_ptr, i); - } - }); - }; - (values, Some(validity.freeze())) - } else { - (values, indices.validity().cloned()) - } -} - -/// Take kernel for single chunk with nulls and arrow array as index that may have nulls. -/// # Safety -/// caller must ensure indices are in bounds -pub unsafe fn take_primitive_unchecked( - arr: &PrimitiveArray, - indices: &IdxArr, -) -> PrimitiveArray { - let (values, validity) = - take_values_and_validity_unchecked(arr.values(), arr.validity(), indices); - PrimitiveArray::new_unchecked(arr.data_type().clone(), values.into(), validity) -} - -/// Forked and adapted from arrow-rs -/// This is faster because it does no bounds checks and allocates directly into aligned memory -/// -/// Takes/filters a list array's inner data using the offsets of the list array. -/// -/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns -/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2 -/// elements) -/// -/// # Safety -/// No bounds checks -pub unsafe fn take_value_indices_from_list( - list: &ListArray, - indices: &IdxArr, -) -> (IdxArr, Offsets) { - let offsets = list.offsets().as_slice(); - - let mut new_offsets = Vec::with_capacity(indices.len()); - // will likely have at least indices.len values - let mut values = Vec::with_capacity(indices.len()); - let mut current_offset = 0; - // add first offset - new_offsets.push(0); - // compute the value indices, and set offsets accordingly - - let indices_values = indices.values(); - - if !indices.has_validity() { - for i in 0..indices.len() { - let idx = *indices_values.get_unchecked(i) as usize; - let start = *offsets.get_unchecked(idx); - let end = *offsets.get_unchecked(idx + 1); - current_offset += end - start; - new_offsets.push(current_offset); - - let mut curr = start; - - // if start == end, this slot is empty - while curr < end { - values.push(curr as IdxSize); - curr += 1; - } - } - } else { - let validity = indices.validity().expect("should have nulls"); - - for i in 0..indices.len() { - if validity.get_bit_unchecked(i) { - let idx = *indices_values.get_unchecked(i) as usize; - let start = *offsets.get_unchecked(idx); - let end = *offsets.get_unchecked(idx + 1); - current_offset += end - start; - new_offsets.push(current_offset); - - let mut curr = start; - - // if start == end, this slot is empty - while curr < end { - values.push(curr as IdxSize); - curr += 1; - } - } else { - new_offsets.push(current_offset); - } - } - } - - // Safety: - // offsets are monotonically increasing. - unsafe { - ( - IdxArr::from_data_default(values.into(), None), - Offsets::new_unchecked(new_offsets), - ) - } -} diff --git a/crates/polars-arrow/src/legacy/kernels/agg_mean.rs b/crates/polars-arrow/src/legacy/kernels/agg_mean.rs index 50e3981e2d75..3a1a6c01faa5 100644 --- a/crates/polars-arrow/src/legacy/kernels/agg_mean.rs +++ b/crates/polars-arrow/src/legacy/kernels/agg_mean.rs @@ -87,7 +87,7 @@ where let sum = chunks.by_ref().zip(validity_masks.by_ref()).fold( Simd::::splat(0.0), |acc, (chunk, validity_chunk)| { - // safety: exact size chunks + // SAFETY: exact size chunks let chunk: [T; LANES] = unsafe { chunk.try_into().unwrap_unchecked() }; let chunk = Simd::from(chunk).cast_custom::(); diff --git a/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs b/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs index 1ace23b09d8a..7f36e92cc14e 100644 --- a/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs +++ b/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs @@ -1,5 +1,5 @@ use crate::array::{ArrayRef, FixedSizeListArray, PrimitiveArray}; -use crate::legacy::compute::take::take_unchecked; +use crate::compute::take::take_unchecked; use crate::legacy::prelude::*; use crate::legacy::utils::CustomIterTools; @@ -39,7 +39,7 @@ fn sub_fixed_size_list_get_indexes(width: usize, index: &PrimitiveArray) -> pub fn sub_fixed_size_list_get_literal(arr: &FixedSizeListArray, index: i64) -> ArrayRef { let take_by = sub_fixed_size_list_get_indexes_literal(arr.size(), arr.len(), index); let values = arr.values(); - // Safety: + // SAFETY: // the indices we generate are in bounds unsafe { take_unchecked(&**values, &take_by) } } @@ -47,7 +47,7 @@ pub fn sub_fixed_size_list_get_literal(arr: &FixedSizeListArray, index: i64) -> pub fn sub_fixed_size_list_get(arr: &FixedSizeListArray, index: &PrimitiveArray) -> ArrayRef { let take_by = sub_fixed_size_list_get_indexes(arr.size(), index); let values = arr.values(); - // Safety: + // SAFETY: // the indices we generate are in bounds unsafe { take_unchecked(&**values, &take_by) } } diff --git a/crates/polars-arrow/src/legacy/kernels/list.rs b/crates/polars-arrow/src/legacy/kernels/list.rs index a4c5723b273c..63fad4bf60c8 100644 --- a/crates/polars-arrow/src/legacy/kernels/list.rs +++ b/crates/polars-arrow/src/legacy/kernels/list.rs @@ -1,5 +1,5 @@ use crate::array::{ArrayRef, ListArray}; -use crate::legacy::compute::take::take_unchecked; +use crate::compute::take::take_unchecked; use crate::legacy::prelude::*; use crate::legacy::trusted_len::TrustedLenPush; use crate::legacy::utils::CustomIterTools; @@ -68,7 +68,7 @@ fn sublist_get_indexes(arr: &ListArray, index: i64) -> IdxArr { pub fn sublist_get(arr: &ListArray, index: i64) -> ArrayRef { let take_by = sublist_get_indexes(arr, index); let values = arr.values(); - // Safety: + // SAFETY: // the indices we generate are in bounds unsafe { take_unchecked(&**values, &take_by) } } @@ -77,7 +77,7 @@ pub fn sublist_get(arr: &ListArray, index: i64) -> ArrayRef { pub fn array_to_unit_list(array: ArrayRef) -> ListArray { let len = array.len(); let mut offsets = Vec::with_capacity(len + 1); - // Safety: we allocated enough + // SAFETY: we allocated enough unsafe { offsets.push_unchecked(0i64); @@ -86,7 +86,7 @@ pub fn array_to_unit_list(array: ArrayRef) -> ListArray { } }; - // Safety: + // SAFETY: // offsets are monotonically increasing unsafe { let offsets: OffsetsBuffer = Offsets::new_unchecked(offsets).into(); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs index 775c8c1f60d5..b5da98336178 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs @@ -1,7 +1,5 @@ -use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls}; use polars_error::polars_ensure; -use super::sum::SumWindow; use super::*; pub struct MeanWindow<'a, T> { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs index efeeb9e183a2..d7368d130c00 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs @@ -1,6 +1,3 @@ -use no_nulls; -use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls}; - use super::*; #[inline] diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs index 5bd216780d1f..a83631868bbd 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs @@ -19,13 +19,13 @@ use super::*; use crate::array::PrimitiveArray; use crate::datatypes::ArrowDataType; use crate::legacy::error::{polars_bail, PolarsResult}; -use crate::legacy::utils::CustomIterTools; use crate::types::NativeType; pub trait RollingAggWindowNoNulls<'a, T: NativeType> { fn new(slice: &'a [T], start: usize, end: usize, params: DynArgs) -> Self; /// Update and recompute the window + /// /// # Safety /// `start` and `end` must be within the windows bounds unsafe fn update(&mut self, start: usize, end: usize) -> T; @@ -51,7 +51,7 @@ where let out = (0..len) .map(|idx| { let (start, end) = det_offsets_fn(idx, window_size, len); - // safety: + // SAFETY: // we are in bounds unsafe { agg_window.update(start, end) } }) diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs index b81eeea249fe..a4d590eca931 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs @@ -1,5 +1,3 @@ -use std::fmt::Debug; - use num_traits::ToPrimitive; use polars_error::polars_ensure; use polars_utils::slice::GetSaferUnchecked; @@ -66,11 +64,11 @@ impl< let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize; return if top_idx == idx { - // safety + // SAFETY: // we are in bounds unsafe { *vals.get_unchecked_release(idx) } } else { - // safety + // SAFETY: // we are in bounds let (mid, mid_plus_1) = unsafe { ( @@ -93,7 +91,7 @@ impl< }, }; - // safety + // SAFETY: // we are in bounds unsafe { *vals.get_unchecked_release(idx) } } @@ -261,7 +259,6 @@ where #[cfg(test)] mod test { use super::*; - use crate::legacy::kernels::rolling::no_nulls::{rolling_max, rolling_min}; #[test] fn test_rolling_median() { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs index e7ee8209e7ff..5c35c3df5840 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs @@ -1,6 +1,3 @@ -use no_nulls; -use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls}; - use super::*; pub struct SumWindow<'a, T> { @@ -32,7 +29,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign> // remove elements that should leave the window let mut recompute_sum = false; for idx in self.last_start..start { - // safety + // SAFETY: // we are in bounds let leaving_value = self.slice.get_unchecked(idx); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs index 7b3b7d3a5928..564f43642d22 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs @@ -1,7 +1,5 @@ -use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls}; use polars_error::polars_ensure; -use super::mean::MeanWindow; use super::*; pub(super) struct SumSquaredWindow<'a, T> { @@ -39,7 +37,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul< // remove elements that should leave the window let mut recompute_sum = false; for idx in self.last_start..start { - // safety + // SAFETY: // we are in bounds let leaving_value = self.slice.get_unchecked(idx); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mean.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mean.rs index 4e2d915b09f1..cbd3c7e981d4 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mean.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mean.rs @@ -1,5 +1,4 @@ -use super::sum::SumWindow; -use super::{rolling_apply_agg_window, RollingAggWindowNulls, *}; +use super::*; pub struct MeanWindow<'a, T> { sum: SumWindow<'a, T>, diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs index 39cfc4940d04..55ea003f5c5a 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs @@ -1,6 +1,3 @@ -use nulls; -use nulls::{rolling_apply_agg_window, RollingAggWindowNulls}; - use super::*; use crate::array::iterator::NonNullValuesIter; use crate::bitmap::utils::count_zeros; @@ -183,7 +180,7 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> MinMaxWindow<'a, T> { // remove elements that should leave the window let mut recompute_extremum = false; for idx in self.last_start..start { - // safety + // SAFETY: // we are in bounds let valid = self.validity.get_bit_unchecked(idx); if valid { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs index ba0df282a12a..4174c5b49826 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs @@ -46,7 +46,7 @@ where { let len = values.len(); let (start, end) = det_offsets_fn(0, window_size, len); - // Safety; we are in bounds + // SAFETY; we are in bounds let mut agg_window = unsafe { Agg::new(values, validity, start, end, params) }; let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn) @@ -59,7 +59,7 @@ where let out = (0..len) .map(|idx| { let (start, end) = det_offsets_fn(idx, window_size, len); - // safety: + // SAFETY: // we are in bounds let agg = unsafe { agg_window.update(start, end) }; match agg { @@ -67,13 +67,13 @@ where if agg_window.is_valid(min_periods) { val } else { - // safety: we are in bounds + // SAFETY: we are in bounds unsafe { validity.set_unchecked(idx, false) }; T::default() } }, None => { - // safety: we are in bounds + // SAFETY: we are in bounds unsafe { validity.set_unchecked(idx, false) }; T::default() }, @@ -94,7 +94,6 @@ mod test { use crate::array::{Array, Int32Array}; use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; - use crate::legacy::kernels::rolling::nulls::mean::rolling_mean; fn get_null_arr() -> PrimitiveArray { // 1, None, -1, 4 diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs index 5db89c320fb5..f10616547b25 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs @@ -155,7 +155,6 @@ mod test { use super::*; use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; - use crate::legacy::kernels::rolling::nulls::{rolling_max, rolling_min}; #[test] fn test_rolling_median_nulls() { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs index 72c35f8a7654..876f60187a79 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs @@ -1,6 +1,3 @@ -use nulls; -use nulls::{rolling_apply_agg_window, RollingAggWindowNulls}; - use super::*; pub struct SumWindow<'a, T> { @@ -66,7 +63,7 @@ impl<'a, T: NativeType + IsFloat + Add + Sub> RollingAgg // remove elements that should leave the window let mut recompute_sum = false; for idx in self.last_start..start { - // safety + // SAFETY: // we are in bounds let valid = self.validity.get_bit_unchecked(idx); if valid { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs index 2869a335457c..1793fc32a4b7 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs @@ -1,7 +1,3 @@ -use mean::MeanWindow; -use nulls; -use nulls::{rolling_apply_agg_window, RollingAggWindowNulls}; - use super::*; pub(super) struct SumSquaredWindow<'a, T> { @@ -67,7 +63,7 @@ impl<'a, T: NativeType + IsFloat + Add + Sub + Mul SortedBuf<'a, T> { } /// Update the window position by setting the `start` index and the `end` index. + /// /// # Safety /// The caller must ensure that `start` and `end` are within bounds of `self.slice` /// @@ -37,10 +38,10 @@ impl<'a, T: NativeType> SortedBuf<'a, T> { } else { // remove elements that should leave the window for idx in self.last_start..start { - // safety + // SAFETY: // we are in bounds let val = self.slice.get_unchecked(idx); - // safety + // SAFETY: // value is present in buf let remove_idx = self .buf @@ -52,7 +53,7 @@ impl<'a, T: NativeType> SortedBuf<'a, T> { // insert elements that enter the window, but insert them sorted for idx in self.last_end..end { - // safety + // SAFETY: // we are in bounds let val = *self.slice.get_unchecked(idx); let insertion_idx = self @@ -120,6 +121,7 @@ impl<'a, T: NativeType> SortedBufNulls<'a, T> { } /// Update the window position by setting the `start` index and the `end` index. + /// /// # Safety /// The caller must ensure that `start` and `end` are within bounds of `self.slice` /// @@ -130,7 +132,7 @@ impl<'a, T: NativeType> SortedBufNulls<'a, T> { } else { // remove elements that should leave the window for idx in self.last_start..start { - // safety + // SAFETY: // we are in bounds let val = if self.validity.get_bit_unchecked(idx) { Some(*self.slice.get_unchecked(idx)) @@ -139,7 +141,7 @@ impl<'a, T: NativeType> SortedBufNulls<'a, T> { None }; - // safety + // SAFETY: // value is present in buf let remove_idx = self .buf @@ -151,7 +153,7 @@ impl<'a, T: NativeType> SortedBufNulls<'a, T> { // insert elements that enter the window, but insert them sorted for idx in self.last_end..end { - // safety + // SAFETY: // we are in bounds let val = if self.validity.get_bit_unchecked(idx) { Some(*self.slice.get_unchecked(idx)) diff --git a/crates/polars-arrow/src/legacy/kernels/set.rs b/crates/polars-arrow/src/legacy/kernels/set.rs index 4fd87905bb74..130d4d956173 100644 --- a/crates/polars-arrow/src/legacy/kernels/set.rs +++ b/crates/polars-arrow/src/legacy/kernels/set.rs @@ -97,10 +97,7 @@ where #[cfg(test)] mod test { - use std::iter::FromIterator; - use super::*; - use crate::array::UInt32Array; #[test] fn test_set_mask() { diff --git a/crates/polars-arrow/src/legacy/kernels/sort_partition.rs b/crates/polars-arrow/src/legacy/kernels/sort_partition.rs index c0ebf5b66404..3021e9a330c0 100644 --- a/crates/polars-arrow/src/legacy/kernels/sort_partition.rs +++ b/crates/polars-arrow/src/legacy/kernels/sort_partition.rs @@ -90,7 +90,7 @@ pub fn partition_to_groups_amortized( let val_ptr = val as *const T; let first_ptr = first as *const T; - // Safety + // SAFETY: // all pointers suffice the invariants let len = unsafe { val_ptr.offset_from(first_ptr) } as IdxSize; out.push([first_idx, len]); diff --git a/crates/polars-arrow/src/legacy/utils.rs b/crates/polars-arrow/src/legacy/utils.rs index 502b03d6087f..316b8ee66bd7 100644 --- a/crates/polars-arrow/src/legacy/utils.rs +++ b/crates/polars-arrow/src/legacy/utils.rs @@ -1,3 +1,5 @@ +use std::borrow::Borrow; + use crate::array::PrimitiveArray; use crate::bitmap::MutableBitmap; use crate::datatypes::ArrowDataType; @@ -11,6 +13,7 @@ pub trait CustomIterTools: Iterator { /// /// # Safety /// The given length must be correct. + #[inline] unsafe fn trust_my_length(self, length: usize) -> TrustMyLength where Self: Sized, @@ -56,6 +59,15 @@ pub trait CustomIterTools: Iterator { } Some(start) } + + fn contains(&mut self, query: &Q) -> bool + where + Self: Sized, + Self::Item: Borrow, + Q: PartialEq, + { + self.any(|x| x.borrow() == query) + } } pub trait CustomIterToolsSized: Iterator + Sized {} diff --git a/crates/polars-arrow/src/offset.rs b/crates/polars-arrow/src/offset.rs index b76e1946d7ce..ae8f568b3eb1 100644 --- a/crates/polars-arrow/src/offset.rs +++ b/crates/polars-arrow/src/offset.rs @@ -141,6 +141,7 @@ impl Offsets { } /// Returns [`Offsets`] assuming that `offsets` fulfills its invariants + /// /// # Safety /// This is safe iff the invariants of this struct are guaranteed in `offsets`. #[inline] @@ -168,6 +169,7 @@ impl Offsets { } /// Returns a range (start, end) corresponding to the position `index` + /// /// # Safety /// `index` must be `< self.len()` #[inline] @@ -377,7 +379,7 @@ impl OffsetsBuffer { pub fn into_mut(self) -> either::Either> { self.0 .into_mut() - // Safety: Offsets and OffsetsBuffer share invariants + // SAFETY: Offsets and OffsetsBuffer share invariants .map_right(|offsets| unsafe { Offsets::new_unchecked(offsets) }) .map_left(Self) } @@ -441,6 +443,7 @@ impl OffsetsBuffer { } /// Returns a range (start, end) corresponding to the position `index` + /// /// # Safety /// `index` must be `< self.len()` #[inline] @@ -462,6 +465,7 @@ impl OffsetsBuffer { } /// Slices this [`OffsetsBuffer`] starting at `offset`. + /// /// # Safety /// The caller must ensure `offset + length <= self.len()` #[inline] diff --git a/crates/polars-arrow/src/pushable.rs b/crates/polars-arrow/src/pushable.rs index db71d8726a8a..12d04ebdcf76 100644 --- a/crates/polars-arrow/src/pushable.rs +++ b/crates/polars-arrow/src/pushable.rs @@ -15,6 +15,7 @@ pub trait Pushable: Sized + Default { fn len(&self) -> usize; fn push_null(&mut self); fn extend_constant(&mut self, additional: usize, value: T); + fn extend_null_constant(&mut self, additional: usize); } impl Pushable for MutableBitmap { @@ -41,6 +42,11 @@ impl Pushable for MutableBitmap { fn extend_constant(&mut self, additional: usize, value: bool) { self.extend_constant(additional, value) } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional, false) + } } impl Pushable for Vec { @@ -67,6 +73,11 @@ impl Pushable for Vec { fn extend_constant(&mut self, additional: usize, value: T) { self.resize(self.len() + additional, value); } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional, T::default()) + } } impl Pushable for Offsets { fn reserve(&mut self, additional: usize) { @@ -91,6 +102,11 @@ impl Pushable for Offsets { fn extend_constant(&mut self, additional: usize, _: usize) { self.extend_constant(additional) } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional) + } } impl Pushable> for MutablePrimitiveArray { @@ -118,6 +134,11 @@ impl Pushable> for MutablePrimitiveArray { fn extend_constant(&mut self, additional: usize, value: Option) { MutablePrimitiveArray::extend_constant(self, additional, value) } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + MutablePrimitiveArray::extend_constant(self, additional, None) + } } impl Pushable<&T> for MutableBinaryViewArray { @@ -157,4 +178,9 @@ impl Pushable<&T> for MutableBinaryViewArray { bitmap.extend_constant(remaining, true) } } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_null(additional); + } } diff --git a/crates/polars-arrow/src/scalar/equal.rs b/crates/polars-arrow/src/scalar/equal.rs index 0d02ca9c9d61..c18d63455913 100644 --- a/crates/polars-arrow/src/scalar/equal.rs +++ b/crates/polars-arrow/src/scalar/equal.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use super::*; -use crate::datatypes::PhysicalType; use crate::{match_integer_type, with_match_primitive_type}; impl PartialEq for dyn Scalar + '_ { @@ -53,6 +52,7 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { FixedSizeList => dyn_eq!(FixedSizeListScalar, lhs, rhs), Union => dyn_eq!(UnionScalar, lhs, rhs), Map => dyn_eq!(MapScalar, lhs, rhs), + Utf8View => dyn_eq!(BinaryViewScalar, lhs, rhs), _ => unimplemented!(), } } diff --git a/crates/polars-arrow/src/types/index.rs b/crates/polars-arrow/src/types/index.rs index 0aedea008fa3..83299a76980f 100644 --- a/crates/polars-arrow/src/types/index.rs +++ b/crates/polars-arrow/src/types/index.rs @@ -1,5 +1,3 @@ -use std::convert::TryFrom; - use super::NativeType; use crate::trusted_len::TrustedLen; @@ -99,5 +97,7 @@ impl Iterator for IndexRange { } } -/// Safety: a range is always of known length +/// # Safety +/// +/// A range is always of known length. unsafe impl TrustedLen for IndexRange {} diff --git a/crates/polars-arrow/src/types/native.rs b/crates/polars-arrow/src/types/native.rs index 45d8d7cb665f..95966004b848 100644 --- a/crates/polars-arrow/src/types/native.rs +++ b/crates/polars-arrow/src/types/native.rs @@ -1,4 +1,3 @@ -use std::convert::TryFrom; use std::ops::Neg; use std::panic::RefUnwindSafe; diff --git a/crates/polars-arrow/src/types/simd/mod.rs b/crates/polars-arrow/src/types/simd/mod.rs index d906c9d25e95..2666abe2ba2c 100644 --- a/crates/polars-arrow/src/types/simd/mod.rs +++ b/crates/polars-arrow/src/types/simd/mod.rs @@ -123,6 +123,7 @@ macro_rules! native_simd { }; } +#[cfg(not(feature = "simd"))] pub(super) use native_simd; // Types do not have specific intrinsics and thus SIMD can't be specialized. diff --git a/crates/polars-arrow/src/types/simd/native.rs b/crates/polars-arrow/src/types/simd/native.rs index af31b8b26bc0..f0cb5436f4f3 100644 --- a/crates/polars-arrow/src/types/simd/native.rs +++ b/crates/polars-arrow/src/types/simd/native.rs @@ -1,7 +1,4 @@ -use std::convert::TryInto; - use super::*; -use crate::types::BitChunkIter; native_simd!(u8x64, u8, 64, u64); native_simd!(u16x32, u16, 32, u32); diff --git a/crates/polars-arrow/tests/it/ffi/mod.rs b/crates/polars-arrow/tests/it/ffi/mod.rs deleted file mode 100644 index 36d8589f579b..000000000000 --- a/crates/polars-arrow/tests/it/ffi/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod data; diff --git a/crates/polars-arrow/tests/it/main.rs b/crates/polars-arrow/tests/it/main.rs deleted file mode 100644 index a21dad004e51..000000000000 --- a/crates/polars-arrow/tests/it/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod ffi; -#[cfg(feature = "io_ipc_compression")] -mod io; diff --git a/crates/polars-compute/Cargo.toml b/crates/polars-compute/Cargo.toml index a13202aa602d..14be8be65f80 100644 --- a/crates/polars-compute/Cargo.toml +++ b/crates/polars-compute/Cargo.toml @@ -11,9 +11,11 @@ description = "Private compute kernels for the Polars DataFrame library" [dependencies] arrow = { workspace = true } bytemuck = { workspace = true } +either = { workspace = true } num-traits = { workspace = true } polars-error = { workspace = true } polars-utils = { workspace = true } +strength_reduce = { workspace = true } [build-dependencies] version_check = { workspace = true } diff --git a/crates/polars-compute/src/arithmetic/float.rs b/crates/polars-compute/src/arithmetic/float.rs new file mode 100644 index 000000000000..3b66e91fdc55 --- /dev/null +++ b/crates/polars-compute/src/arithmetic/float.rs @@ -0,0 +1,115 @@ +use arrow::array::PrimitiveArray as PArr; + +use super::PrimitiveArithmeticKernelImpl; +use crate::arity::{prim_binary_values, prim_unary_values}; + +macro_rules! impl_float_arith_kernel { + ($T:ty) => { + impl PrimitiveArithmeticKernelImpl for $T { + type TrueDivT = $T; + + fn prim_wrapping_neg(lhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(lhs, |x| -x) + } + + fn prim_wrapping_add(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| l + r) + } + + fn prim_wrapping_sub(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| l - r) + } + + fn prim_wrapping_mul(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| l * r) + } + + fn prim_wrapping_floor_div(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| (l / r).floor()) + } + + fn prim_wrapping_trunc_div(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| (l / r).trunc()) + } + + fn prim_wrapping_mod(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| l - r * (l / r).floor()) + } + + fn prim_wrapping_add_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0.0 { + return lhs; + } + prim_unary_values(lhs, |x| x + rhs) + } + + fn prim_wrapping_sub_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0.0 { + return lhs; + } + Self::prim_wrapping_add_scalar(lhs, -rhs) + } + + fn prim_wrapping_sub_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + if lhs == 0.0 { + Self::prim_wrapping_neg(rhs) + } else { + prim_unary_values(rhs, |x| lhs - x) + } + } + + fn prim_wrapping_mul_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + // No optimization for multiplication by zero, would invalidate NaNs/infinities. + if rhs == 1.0 { + lhs + } else if rhs == -1.0 { + Self::prim_wrapping_neg(lhs) + } else { + prim_unary_values(lhs, |x| x * rhs) + } + } + + fn prim_wrapping_floor_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + let inv = 1.0 / rhs; + prim_unary_values(lhs, |x| (x * inv).floor()) + } + + fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(rhs, |x| (lhs / x).floor()) + } + + fn prim_wrapping_trunc_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + let inv = 1.0 / rhs; + prim_unary_values(lhs, |x| (x * inv).trunc()) + } + + fn prim_wrapping_trunc_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(rhs, |x| (lhs / x).trunc()) + } + + fn prim_wrapping_mod_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + let inv = 1.0 / rhs; + prim_unary_values(lhs, |x| x - rhs * (x * inv).floor()) + } + + fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(rhs, |x| lhs - x * (lhs / x).floor()) + } + + fn prim_true_div(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr { + prim_binary_values(lhs, rhs, |l, r| l / r) + } + + fn prim_true_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr { + Self::prim_wrapping_mul_scalar(lhs, 1.0 / rhs) + } + + fn prim_true_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr { + prim_unary_values(rhs, |x| lhs / x) + } + } + }; +} + +impl_float_arith_kernel!(f32); +impl_float_arith_kernel!(f64); diff --git a/crates/polars-compute/src/arithmetic/mod.rs b/crates/polars-compute/src/arithmetic/mod.rs new file mode 100644 index 000000000000..1724142b6a7f --- /dev/null +++ b/crates/polars-compute/src/arithmetic/mod.rs @@ -0,0 +1,142 @@ +use std::any::TypeId; + +use arrow::array::{Array, PrimitiveArray}; +use arrow::types::NativeType; + +// Low-level comparison kernel. +pub trait ArithmeticKernel: Sized + Array { + type Scalar; + type TrueDivT: NativeType; + + fn wrapping_neg(self) -> Self; + fn wrapping_add(self, rhs: Self) -> Self; + fn wrapping_sub(self, rhs: Self) -> Self; + fn wrapping_mul(self, rhs: Self) -> Self; + fn wrapping_floor_div(self, rhs: Self) -> Self; + fn wrapping_trunc_div(self, rhs: Self) -> Self; + fn wrapping_mod(self, rhs: Self) -> Self; + + fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self; + fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self; + fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self; + fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self; + + fn true_div(self, rhs: Self) -> PrimitiveArray; + fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray; + fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray; + + // TODO: remove these. + // These are flooring division for integer types, true division for floating point types. + fn legacy_div(self, rhs: Self) -> Self { + if TypeId::of::() == TypeId::of::>() { + let ret = self.true_div(rhs); + unsafe { + let cast_ret = std::mem::transmute_copy(&ret); + std::mem::forget(ret); + cast_ret + } + } else { + self.wrapping_floor_div(rhs) + } + } + fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self { + if TypeId::of::() == TypeId::of::>() { + let ret = self.true_div_scalar(rhs); + unsafe { + let cast_ret = std::mem::transmute_copy(&ret); + std::mem::forget(ret); + cast_ret + } + } else { + self.wrapping_floor_div_scalar(rhs) + } + } + + fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { + if TypeId::of::() == TypeId::of::>() { + let ret = ArithmeticKernel::true_div_scalar_lhs(lhs, rhs); + unsafe { + let cast_ret = std::mem::transmute_copy(&ret); + std::mem::forget(ret); + cast_ret + } + } else { + ArithmeticKernel::wrapping_floor_div_scalar_lhs(lhs, rhs) + } + } +} + +// Proxy trait so one can bound T: HasPrimitiveArithmeticKernel. Sadly Rust +// doesn't support adding supertraits for other types. +#[allow(private_bounds)] +pub trait HasPrimitiveArithmeticKernel: NativeType + PrimitiveArithmeticKernelImpl {} +impl HasPrimitiveArithmeticKernel for T {} + +use PrimitiveArray as PArr; + +#[doc(hidden)] +pub trait PrimitiveArithmeticKernelImpl: NativeType { + type TrueDivT: NativeType; + + fn prim_wrapping_neg(lhs: PArr) -> PArr; + fn prim_wrapping_add(lhs: PArr, rhs: PArr) -> PArr; + fn prim_wrapping_sub(lhs: PArr, rhs: PArr) -> PArr; + fn prim_wrapping_mul(lhs: PArr, rhs: PArr) -> PArr; + fn prim_wrapping_floor_div(lhs: PArr, rhs: PArr) -> PArr; + fn prim_wrapping_trunc_div(lhs: PArr, rhs: PArr) -> PArr; + fn prim_wrapping_mod(lhs: PArr, rhs: PArr) -> PArr; + + fn prim_wrapping_add_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_sub_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_sub_scalar_lhs(lhs: Self, rhs: PArr) -> PArr; + fn prim_wrapping_mul_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_floor_div_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_floor_div_scalar_lhs(lhs: Self, rhs: PArr) -> PArr; + fn prim_wrapping_trunc_div_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_trunc_div_scalar_lhs(lhs: Self, rhs: PArr) -> PArr; + fn prim_wrapping_mod_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_mod_scalar_lhs(lhs: Self, rhs: PArr) -> PArr; + + fn prim_true_div(lhs: PArr, rhs: PArr) -> PArr; + fn prim_true_div_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_true_div_scalar_lhs(lhs: Self, rhs: PArr) -> PArr; +} + +#[rustfmt::skip] +impl ArithmeticKernel for PrimitiveArray { + type Scalar = T; + type TrueDivT = T::TrueDivT; + + fn wrapping_neg(self) -> Self { T::prim_wrapping_neg(self) } + fn wrapping_add(self, rhs: Self) -> Self { T::prim_wrapping_add(self, rhs) } + fn wrapping_sub(self, rhs: Self) -> Self { T::prim_wrapping_sub(self, rhs) } + fn wrapping_mul(self, rhs: Self) -> Self { T::prim_wrapping_mul(self, rhs) } + fn wrapping_floor_div(self, rhs: Self) -> Self { T::prim_wrapping_floor_div(self, rhs) } + fn wrapping_trunc_div(self, rhs: Self) -> Self { T::prim_wrapping_trunc_div(self, rhs) } + fn wrapping_mod(self, rhs: Self) -> Self { T::prim_wrapping_mod(self, rhs) } + + fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_add_scalar(self, rhs) } + fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_sub_scalar(self, rhs) } + fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_sub_scalar_lhs(lhs, rhs) } + fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mul_scalar(self, rhs) } + fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_floor_div_scalar(self, rhs) } + fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_floor_div_scalar_lhs(lhs, rhs) } + fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_trunc_div_scalar(self, rhs) } + fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_trunc_div_scalar_lhs(lhs, rhs) } + fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mod_scalar(self, rhs) } + fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_mod_scalar_lhs(lhs, rhs) } + + fn true_div(self, rhs: Self) -> PrimitiveArray { T::prim_true_div(self, rhs) } + fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray { T::prim_true_div_scalar(self, rhs) } + fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray { T::prim_true_div_scalar_lhs(lhs, rhs) } +} + +mod float; +mod signed; +mod unsigned; diff --git a/crates/polars-compute/src/arithmetic/signed.rs b/crates/polars-compute/src/arithmetic/signed.rs new file mode 100644 index 000000000000..94a13ba10394 --- /dev/null +++ b/crates/polars-compute/src/arithmetic/signed.rs @@ -0,0 +1,234 @@ +use arrow::array::{PrimitiveArray as PArr, StaticArray}; +use arrow::compute::utils::{combine_validities_and, combine_validities_and3}; +use polars_utils::floor_divmod::FloorDivMod; +use strength_reduce::*; + +use super::PrimitiveArithmeticKernelImpl; +use crate::arity::{prim_binary_values, prim_unary_values}; +use crate::comparisons::TotalOrdKernel; + +macro_rules! impl_signed_arith_kernel { + ($T:ty, $StrRed:ty) => { + impl PrimitiveArithmeticKernelImpl for $T { + type TrueDivT = f64; + + fn prim_wrapping_neg(lhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(lhs, |x| x.wrapping_neg()) + } + + fn prim_wrapping_add(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_add(b)) + } + + fn prim_wrapping_sub(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_sub(b)) + } + + fn prim_wrapping_mul(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_mul(b)) + } + + fn prim_wrapping_floor_div(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> { + let mask = other.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and3( + lhs.take_validity().as_ref(), // Take validity so we don't + other.take_validity().as_ref(), // compute combination twice. + Some(&mask), + ); + let ret = + prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_floor_div_mod(rhs).0); + ret.with_validity(valid) + } + + fn prim_wrapping_trunc_div(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> { + let mask = other.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and3( + lhs.take_validity().as_ref(), // Take validity so we don't + other.take_validity().as_ref(), // compute combination twice. + Some(&mask), + ); + let ret = prim_binary_values(lhs, other, |lhs, rhs| { + if rhs != 0 { + lhs.wrapping_div(rhs) + } else { + 0 + } + }); + ret.with_validity(valid) + } + + fn prim_wrapping_mod(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> { + let mask = other.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and3( + lhs.take_validity().as_ref(), // Take validity so we don't + other.take_validity().as_ref(), // compute combination twice. + Some(&mask), + ); + let ret = + prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_floor_div_mod(rhs).1); + ret.with_validity(valid) + } + + fn prim_wrapping_add_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + prim_unary_values(lhs, |x| x.wrapping_add(rhs)) + } + + fn prim_wrapping_sub_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + Self::prim_wrapping_add_scalar(lhs, rhs.wrapping_neg()) + } + + fn prim_wrapping_sub_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(rhs, |x| lhs.wrapping_sub(x)) + } + + fn prim_wrapping_mul_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + let scalar_u = rhs.unsigned_abs(); + if rhs == 0 { + lhs.fill_with(0) + } else if rhs == 1 { + lhs + } else if scalar_u & (scalar_u - 1) == 0 { + // Power of two. + let shift = scalar_u.trailing_zeros(); + if rhs > 0 { + prim_unary_values(lhs, |x| x << shift) + } else { + prim_unary_values(lhs, |x| (x << shift).wrapping_neg()) + } + } else { + prim_unary_values(lhs, |x| x.wrapping_mul(rhs)) + } + } + + fn prim_wrapping_floor_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + PArr::full_null(lhs.len(), lhs.data_type().clone()) + } else if rhs == -1 { + Self::prim_wrapping_neg(lhs) + } else if rhs == 1 { + lhs + } else { + let red = <$StrRed>::new(rhs.unsigned_abs()); + prim_unary_values(lhs, |x| { + let (quot, rem) = <$StrRed>::div_rem(x.unsigned_abs(), red); + if (x < 0) != (rhs < 0) { + // Different signs: result should be negative. + // Since we handled rhs.abs() <= 1, quot fits. + let mut ret = -(quot as $T); + if rem != 0 { + // Division had remainder, subtract 1 to floor to + // negative infinity, as we truncated to zero. + ret -= 1; + } + ret + } else { + quot as $T + } + }) + } + } + + fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + if lhs == 0 { + return rhs.fill_with(0); + } + + let mask = rhs.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and(rhs.validity(), Some(&mask)); + let ret = prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).0); + ret.with_validity(valid) + } + + fn prim_wrapping_trunc_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + PArr::full_null(lhs.len(), lhs.data_type().clone()) + } else if rhs == -1 { + Self::prim_wrapping_neg(lhs) + } else if rhs == 1 { + lhs + } else { + let red = <$StrRed>::new(rhs.unsigned_abs()); + prim_unary_values(lhs, |x| { + let quot = x.unsigned_abs() / red; + if (x < 0) != (rhs < 0) { + // Different signs: result should be negative. + -(quot as $T) + } else { + quot as $T + } + }) + } + } + + fn prim_wrapping_trunc_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + if lhs == 0 { + return rhs.fill_with(0); + } + + let mask = rhs.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and(rhs.validity(), Some(&mask)); + let ret = prim_unary_values(rhs, |x| if x != 0 { lhs.wrapping_div(x) } else { 0 }); + ret.with_validity(valid) + } + + fn prim_wrapping_mod_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + PArr::full_null(lhs.len(), lhs.data_type().clone()) + } else if rhs == -1 || rhs == 1 { + lhs.fill_with(0) + } else { + let scalar_u = rhs.unsigned_abs(); + let red = <$StrRed>::new(scalar_u); + prim_unary_values(lhs, |x| { + // Remainder fits in signed type after reduction. + // Largest possible modulo -I::MIN, with + // -I::MIN-1 == I::MAX as largest remainder. + let mut rem_u = x.unsigned_abs() % red; + + // Mixed signs: swap direction of remainder. + if rem_u != 0 && (rhs < 0) != (x < 0) { + rem_u = scalar_u - rem_u; + } + + // Remainder should have sign of RHS. + if rhs < 0 { + -(rem_u as $T) + } else { + rem_u as $T + } + }) + } + } + + fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + if lhs == 0 { + return rhs.fill_with(0); + } + + let mask = rhs.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and(rhs.validity(), Some(&mask)); + let ret = prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).1); + ret.with_validity(valid) + } + + fn prim_true_div(lhs: PArr<$T>, other: PArr<$T>) -> PArr { + prim_binary_values(lhs, other, |a, b| a as f64 / b as f64) + } + + fn prim_true_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr { + let inv = 1.0 / rhs as f64; + prim_unary_values(lhs, |x| x as f64 * inv) + } + + fn prim_true_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr { + prim_unary_values(rhs, |x| lhs as f64 / x as f64) + } + } + }; +} + +impl_signed_arith_kernel!(i8, StrengthReducedU8); +impl_signed_arith_kernel!(i16, StrengthReducedU16); +impl_signed_arith_kernel!(i32, StrengthReducedU32); +impl_signed_arith_kernel!(i64, StrengthReducedU64); +impl_signed_arith_kernel!(i128, StrengthReducedU128); diff --git a/crates/polars-compute/src/arithmetic/unsigned.rs b/crates/polars-compute/src/arithmetic/unsigned.rs new file mode 100644 index 000000000000..67023ef07dd8 --- /dev/null +++ b/crates/polars-compute/src/arithmetic/unsigned.rs @@ -0,0 +1,154 @@ +use arrow::array::{PrimitiveArray as PArr, StaticArray}; +use arrow::compute::utils::{combine_validities_and, combine_validities_and3}; +use strength_reduce::*; + +use super::PrimitiveArithmeticKernelImpl; +use crate::arity::{prim_binary_values, prim_unary_values}; +use crate::comparisons::TotalOrdKernel; + +macro_rules! impl_unsigned_arith_kernel { + ($T:ty, $StrRed:ty) => { + impl PrimitiveArithmeticKernelImpl for $T { + type TrueDivT = f64; + + fn prim_wrapping_neg(lhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(lhs, |x| x.wrapping_neg()) + } + + fn prim_wrapping_add(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_add(b)) + } + + fn prim_wrapping_sub(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_sub(b)) + } + + fn prim_wrapping_mul(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_mul(b)) + } + + fn prim_wrapping_floor_div(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> { + let mask = other.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and3( + lhs.take_validity().as_ref(), // Take validity so we don't + other.take_validity().as_ref(), // compute combination twice. + Some(&mask), + ); + let ret = prim_binary_values(lhs, other, |a, b| if b != 0 { a / b } else { 0 }); + ret.with_validity(valid) + } + + fn prim_wrapping_trunc_div(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + Self::prim_wrapping_floor_div(lhs, rhs) + } + + fn prim_wrapping_mod(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> { + let mask = other.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and3( + lhs.take_validity().as_ref(), // Take validity so we don't + other.take_validity().as_ref(), // compute combination twice. + Some(&mask), + ); + let ret = prim_binary_values(lhs, other, |a, b| if b != 0 { a % b } else { 0 }); + ret.with_validity(valid) + } + + fn prim_wrapping_add_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + prim_unary_values(lhs, |x| x.wrapping_add(rhs)) + } + + fn prim_wrapping_sub_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + Self::prim_wrapping_add_scalar(lhs, rhs.wrapping_neg()) + } + + fn prim_wrapping_sub_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(rhs, |x| lhs.wrapping_sub(x)) + } + + fn prim_wrapping_mul_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + lhs.fill_with(0) + } else if rhs == 1 { + lhs + } else if rhs & (rhs - 1) == 0 { + // Power of two. + let shift = rhs.trailing_zeros(); + prim_unary_values(lhs, |x| x << shift) + } else { + prim_unary_values(lhs, |x| x.wrapping_mul(rhs)) + } + } + + fn prim_wrapping_floor_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + PArr::full_null(lhs.len(), lhs.data_type().clone()) + } else if rhs == 1 { + lhs + } else { + let red = <$StrRed>::new(rhs); + prim_unary_values(lhs, |x| x / red) + } + } + + fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + if lhs == 0 { + return rhs.fill_with(0); + } + + let mask = rhs.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and(rhs.validity(), Some(&mask)); + let ret = prim_unary_values(rhs, |x| if x != 0 { lhs / x } else { 0 }); + ret.with_validity(valid) + } + + fn prim_wrapping_trunc_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + Self::prim_wrapping_floor_div_scalar(lhs, rhs) + } + + fn prim_wrapping_trunc_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + Self::prim_wrapping_floor_div_scalar_lhs(lhs, rhs) + } + + fn prim_wrapping_mod_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + PArr::full_null(lhs.len(), lhs.data_type().clone()) + } else if rhs == 1 { + lhs.fill_with(0) + } else { + let red = <$StrRed>::new(rhs); + prim_unary_values(lhs, |x| x % red) + } + } + + fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + if lhs == 0 { + return rhs.fill_with(0); + } + + let mask = rhs.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and(rhs.validity(), Some(&mask)); + let ret = prim_unary_values(rhs, |x| if x != 0 { lhs % x } else { 0 }); + ret.with_validity(valid) + } + + fn prim_true_div(lhs: PArr<$T>, other: PArr<$T>) -> PArr { + prim_binary_values(lhs, other, |a, b| a as f64 / b as f64) + } + + fn prim_true_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr { + let inv = 1.0 / rhs as f64; + prim_unary_values(lhs, |x| x as f64 * inv) + } + + fn prim_true_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr { + prim_unary_values(rhs, |x| lhs as f64 / x as f64) + } + } + }; +} + +impl_unsigned_arith_kernel!(u8, StrengthReducedU8); +impl_unsigned_arith_kernel!(u16, StrengthReducedU16); +impl_unsigned_arith_kernel!(u32, StrengthReducedU32); +impl_unsigned_arith_kernel!(u64, StrengthReducedU64); +impl_unsigned_arith_kernel!(u128, StrengthReducedU128); diff --git a/crates/polars-compute/src/arity.rs b/crates/polars-compute/src/arity.rs new file mode 100644 index 000000000000..33c8b0eb0584 --- /dev/null +++ b/crates/polars-compute/src/arity.rs @@ -0,0 +1,132 @@ +use arrow::array::PrimitiveArray; +use arrow::compute::utils::combine_validities_and; +use arrow::types::NativeType; + +/// To reduce codegen we use these helpers where the input and output arrays +/// may overlap. These are marked to never be inlined, this way only a single +/// unrolled kernel gets generated, even if we call it in multiple ways. +/// +/// # Safety +/// - arr must point to a readable slice of length len. +/// - out must point to a writable slice of length len. +#[inline(never)] +unsafe fn ptr_apply_unary_kernel O>( + arr: *const I, + out: *mut O, + len: usize, + op: F, +) { + for i in 0..len { + let ret = op(arr.add(i).read()); + out.add(i).write(ret); + } +} + +/// # Safety +/// - left must point to a readable slice of length len. +/// - right must point to a readable slice of length len. +/// - out must point to a writable slice of length len. +#[inline(never)] +unsafe fn ptr_apply_binary_kernel O>( + left: *const L, + right: *const R, + out: *mut O, + len: usize, + op: F, +) { + for i in 0..len { + let ret = op(left.add(i).read(), right.add(i).read()); + out.add(i).write(ret); + } +} + +/// Applies a function to all the values (regardless of nullability). +/// +/// May reuse the memory of the array if possible. +pub fn prim_unary_values(mut arr: PrimitiveArray, op: F) -> PrimitiveArray +where + I: NativeType, + O: NativeType, + F: Fn(I) -> O, +{ + let len = arr.len(); + + // Reuse memory if possible. + if std::mem::size_of::() == std::mem::size_of::() + && std::mem::align_of::() == std::mem::align_of::() + { + if let Some(values) = arr.get_mut_values() { + let ptr = values.as_mut_ptr(); + // SAFETY: checked same size & alignment I/O, NativeType is always Pod. + unsafe { ptr_apply_unary_kernel(ptr, ptr as *mut O, len, op) } + return arr.transmute::(); + } + } + + let mut out = Vec::with_capacity(len); + unsafe { + // SAFETY: checked pointers point to slices of length len. + ptr_apply_unary_kernel(arr.values().as_ptr(), out.as_mut_ptr(), len, op); + out.set_len(len); + } + PrimitiveArray::from_vec(out).with_validity(arr.take_validity()) +} + +/// Apply a binary function to all the values (regardless of nullability) +/// in (lhs, rhs). Combines the validities with a bitand. +/// +/// May reuse the memory of one of its arguments if possible. +pub fn prim_binary_values( + mut lhs: PrimitiveArray, + mut rhs: PrimitiveArray, + op: F, +) -> PrimitiveArray +where + L: NativeType, + R: NativeType, + O: NativeType, + F: Fn(L, R) -> O, +{ + assert_eq!(lhs.len(), rhs.len()); + let len = lhs.len(); + + let validity = combine_validities_and(lhs.validity(), rhs.validity()); + + // Reuse memory if possible. + if std::mem::size_of::() == std::mem::size_of::() + && std::mem::align_of::() == std::mem::align_of::() + { + if let Some(lv) = lhs.get_mut_values() { + let lp = lv.as_mut_ptr(); + let rp = rhs.values().as_ptr(); + unsafe { + // SAFETY: checked same size & alignment L/O, NativeType is always Pod. + ptr_apply_binary_kernel(lp, rp, lp as *mut O, len, op); + } + return lhs.transmute::().with_validity(validity); + } + } + if std::mem::size_of::() == std::mem::size_of::() + && std::mem::align_of::() == std::mem::align_of::() + { + if let Some(rv) = rhs.get_mut_values() { + let lp = lhs.values().as_ptr(); + let rp = rv.as_mut_ptr(); + unsafe { + // SAFETY: checked same size & alignment R/O, NativeType is always Pod. + ptr_apply_binary_kernel(lp, rp, rp as *mut O, len, op); + } + return rhs.transmute::().with_validity(validity); + } + } + + let mut out = Vec::with_capacity(len); + unsafe { + // SAFETY: checked pointers point to slices of length len. + let lp = lhs.values().as_ptr(); + let rp = rhs.values().as_ptr(); + ptr_apply_binary_kernel(lp, rp, out.as_mut_ptr(), len, op); + out.set_len(len); + } + PrimitiveArray::from_vec(out).with_validity(validity) +} diff --git a/crates/polars-compute/src/filter/boolean.rs b/crates/polars-compute/src/filter/boolean.rs index 0050477ac5d7..4030d819a59b 100644 --- a/crates/polars-compute/src/filter/boolean.rs +++ b/crates/polars-compute/src/filter/boolean.rs @@ -93,7 +93,7 @@ where unsafe { new.extend_from_slice_unchecked(chunk.to_ne_bytes().as_ref(), 0, size); - // safety: invariant offset + length <= slice.len() + // SAFETY: invariant offset + length <= slice.len() new_validity.extend_from_slice_unchecked( validity_chunk.to_ne_bytes().as_ref(), 0, diff --git a/crates/polars-compute/src/filter/primitive.rs b/crates/polars-compute/src/filter/primitive.rs index 336009b8a233..085074efbe71 100644 --- a/crates/polars-compute/src/filter/primitive.rs +++ b/crates/polars-compute/src/filter/primitive.rs @@ -112,7 +112,7 @@ where std::ptr::copy(chunk.as_ptr(), dst, size); dst = dst.add(size); - // safety: invariant offset + length <= slice.len() + // SAFETY: invariant offset + length <= slice.len() new_validity.extend_from_slice_unchecked( validity_chunk.to_ne_bytes().as_ref(), 0, diff --git a/crates/polars-compute/src/lib.rs b/crates/polars-compute/src/lib.rs index 759cd67a82b3..0cd894d38013 100644 --- a/crates/polars-compute/src/lib.rs +++ b/crates/polars-compute/src/lib.rs @@ -1,5 +1,8 @@ #![cfg_attr(feature = "simd", feature(portable_simd))] +pub mod arithmetic; pub mod comparisons; pub mod filter; pub mod min_max; + +pub mod arity; diff --git a/crates/polars-compute/src/min_max/mod.rs b/crates/polars-compute/src/min_max/mod.rs index 0df4f735a727..5278cb9b1dd7 100644 --- a/crates/polars-compute/src/min_max/mod.rs +++ b/crates/polars-compute/src/min_max/mod.rs @@ -8,8 +8,18 @@ pub trait MinMaxKernel { fn min_ignore_nan_kernel(&self) -> Option>; fn max_ignore_nan_kernel(&self) -> Option>; + fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + Some((self.min_ignore_nan_kernel()?, self.max_ignore_nan_kernel()?)) + } + fn min_propagate_nan_kernel(&self) -> Option>; fn max_propagate_nan_kernel(&self) -> Option>; + fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + Some(( + self.min_propagate_nan_kernel()?, + self.max_propagate_nan_kernel()?, + )) + } } // Trait to enable the scalar blanket implementation. diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 052fe4e890d8..1915eb71957b 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -94,7 +94,6 @@ dataframe_arithmetic = [] product = [] unique_counts = [] partition_by = [] -chunked_ids = [] describe = [] timezones = ["chrono-tz", "arrow/chrono-tz", "arrow/timezones"] dynamic_group_by = ["dtype-datetime", "dtype-date"] @@ -144,7 +143,6 @@ docs-selection = [ "dataframe_arithmetic", "product", "describe", - "chunked_ids", "partition_by", "algorithm_group_by", ] diff --git a/crates/polars-core/src/chunked_array/arithmetic/decimal.rs b/crates/polars-core/src/chunked_array/arithmetic/decimal.rs index c3d684475ea0..3e5067656621 100644 --- a/crates/polars-core/src/chunked_array/arithmetic/decimal.rs +++ b/crates/polars-core/src/chunked_array/arithmetic/decimal.rs @@ -1,112 +1,13 @@ -use arrow::legacy::compute::arithmetics::decimal; - use super::*; -use crate::prelude::DecimalChunked; -use crate::utils::align_chunks_binary; - -// TODO: remove -impl ArrayArithmetics for i128 { - fn add(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { - unimplemented!() - } - - fn sub(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { - unimplemented!() - } - - fn mul(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { - unimplemented!() - } - - fn div(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { - unimplemented!() - } - - fn div_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { - unimplemented!() - } - - fn rem(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { - unimplemented!("requires support in arrow2 crate") - } - - fn rem_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { - unimplemented!("requires support in arrow2 crate") - } -} - -impl DecimalChunked { - fn arithmetic_helper( - &self, - rhs: &DecimalChunked, - kernel: Kernel, - operation_lhs: ScalarKernelLhs, - operation_rhs: ScalarKernelRhs, - ) -> PolarsResult - where - Kernel: - Fn(&PrimitiveArray, &PrimitiveArray) -> PolarsResult>, - ScalarKernelLhs: Fn(&PrimitiveArray, i128) -> PolarsResult>, - ScalarKernelRhs: Fn(i128, &PrimitiveArray) -> PolarsResult>, - { - let lhs = self; - - let mut ca = match (lhs.len(), rhs.len()) { - (a, b) if a == b => { - let (lhs, rhs) = align_chunks_binary(lhs, rhs); - let chunks = lhs - .downcast_iter() - .zip(rhs.downcast_iter()) - .map(|(lhs, rhs)| kernel(lhs, rhs).map(|a| Box::new(a) as ArrayRef)) - .collect::>()?; - unsafe { lhs.copy_with_chunks(chunks, false, false) } - }, - // broadcast right path - (_, 1) => { - let opt_rhs = rhs.get(0); - match opt_rhs { - None => ChunkedArray::full_null(lhs.name(), lhs.len()), - Some(rhs_val) => { - let chunks = lhs - .downcast_iter() - .map(|lhs| operation_lhs(lhs, rhs_val).map(|a| Box::new(a) as ArrayRef)) - .collect::>()?; - unsafe { lhs.copy_with_chunks(chunks, false, false) } - }, - } - }, - (1, _) => { - let opt_lhs = lhs.get(0); - match opt_lhs { - None => ChunkedArray::full_null(lhs.name(), rhs.len()), - Some(lhs_val) => { - let chunks = rhs - .downcast_iter() - .map(|rhs| operation_rhs(lhs_val, rhs).map(|a| Box::new(a) as ArrayRef)) - .collect::>()?; - unsafe { lhs.copy_with_chunks(chunks, false, false) } - }, - } - }, - _ => { - polars_bail!(ComputeError: "Cannot apply operation on arrays of different lengths") - }, - }; - ca.rename(lhs.name()); - Ok(ca.into_decimal_unchecked(self.precision(), self.scale())) - } -} impl Add for &DecimalChunked { type Output = PolarsResult; fn add(self, rhs: Self) -> Self::Output { - self.arithmetic_helper( - rhs, - decimal::add, - |lhs, rhs_val| decimal::add_scalar(lhs, rhs_val, &rhs.dtype().to_arrow(true)), - |lhs_val, rhs| decimal::add_scalar(rhs, lhs_val, &self.dtype().to_arrow(true)), - ) + let scale = self.scale().max(rhs.scale()); + let lhs = self.to_scale(scale)?; + let rhs = rhs.to_scale(scale)?; + Ok((&lhs.0 + &rhs.0).into_decimal_unchecked(None, scale)) } } @@ -114,12 +15,10 @@ impl Sub for &DecimalChunked { type Output = PolarsResult; fn sub(self, rhs: Self) -> Self::Output { - self.arithmetic_helper( - rhs, - decimal::sub, - decimal::sub_scalar, - decimal::sub_scalar_swapped, - ) + let scale = self.scale().max(rhs.scale()); + let lhs = self.to_scale(scale)?; + let rhs = rhs.to_scale(scale)?; + Ok((&lhs.0 - &rhs.0).into_decimal_unchecked(None, scale)) } } @@ -127,12 +26,8 @@ impl Mul for &DecimalChunked { type Output = PolarsResult; fn mul(self, rhs: Self) -> Self::Output { - self.arithmetic_helper( - rhs, - decimal::mul, - |lhs, rhs_val| decimal::mul_scalar(lhs, rhs_val, &rhs.dtype().to_arrow(true)), - |lhs_val, rhs| decimal::mul_scalar(rhs, lhs_val, &self.dtype().to_arrow(true)), - ) + let scale = self.scale() + rhs.scale(); + Ok((&self.0 * &rhs.0).into_decimal_unchecked(None, scale)) } } @@ -140,11 +35,9 @@ impl Div for &DecimalChunked { type Output = PolarsResult; fn div(self, rhs: Self) -> Self::Output { - self.arithmetic_helper( - rhs, - decimal::div, - |lhs, rhs_val| decimal::div_scalar(lhs, rhs_val, &rhs.dtype().to_arrow(true)), - |lhs_val, rhs| decimal::div_scalar_swapped(lhs_val, &self.dtype().to_arrow(true), rhs), - ) + // Follow postgres and MySQL adding a fixed scale increment of 4 + let scale = self.scale() + 4; + let lhs = self.to_scale(scale + rhs.scale())?; + Ok((&lhs.0 / &rhs.0).into_decimal_unchecked(None, scale)) } } diff --git a/crates/polars-core/src/chunked_array/arithmetic/mod.rs b/crates/polars-core/src/chunked_array/arithmetic/mod.rs index 306fa2e6ba3a..874ed2033097 100644 --- a/crates/polars-core/src/chunked_array/arithmetic/mod.rs +++ b/crates/polars-core/src/chunked_array/arithmetic/mod.rs @@ -5,64 +5,11 @@ mod numeric; use std::ops::{Add, Div, Mul, Rem, Sub}; -use arrow::array::PrimitiveArray; -use arrow::compute::arithmetics::basic; -use arrow::compute::arity_assign; use arrow::compute::utils::combine_validities_and; -use arrow::types::NativeType; -use num_traits::{Num, NumCast, ToPrimitive, Zero}; -pub(super) use numeric::arithmetic_helper; +use num_traits::{Num, NumCast, ToPrimitive}; +pub use numeric::ArithmeticChunked; use crate::prelude::*; -use crate::series::IsSorted; -use crate::utils::align_chunks_binary_owned; - -pub trait ArrayArithmetics -where - Self: NativeType, -{ - fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn div_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray; - fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn rem_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray; -} - -macro_rules! native_array_arithmetics { - ($ty: ty) => { - impl ArrayArithmetics for $ty - { - fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::add(lhs, rhs) - } - fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::sub(lhs, rhs) - } - fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::mul(lhs, rhs) - } - fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::div(lhs, rhs) - } - fn div_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray { - basic::div_scalar(lhs, rhs) - } - fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::rem(lhs, rhs) - } - fn rem_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray { - basic::rem_scalar(lhs, rhs) - } - } - }; - ($($ty:ty),*) => { - $(native_array_arithmetics!($ty);)* - } -} - -native_array_arithmetics!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64); #[inline] fn concat_binary_arrs(l: &[u8], r: &[u8], buf: &mut Vec) { diff --git a/crates/polars-core/src/chunked_array/arithmetic/numeric.rs b/crates/polars-core/src/chunked_array/arithmetic/numeric.rs index 7729ca5f75db..4c996761cf5e 100644 --- a/crates/polars-core/src/chunked_array/arithmetic/numeric.rs +++ b/crates/polars-core/src/chunked_array/arithmetic/numeric.rs @@ -1,388 +1,413 @@ -use super::*; +use polars_compute::arithmetic::ArithmeticKernel; -pub(crate) fn arithmetic_helper( - lhs: &ChunkedArray, - rhs: &ChunkedArray, - kernel: Kernel, - operation: F, -) -> ChunkedArray -where - T: PolarsNumericType, - Kernel: Fn(&PrimitiveArray, &PrimitiveArray) -> PrimitiveArray, - F: Fn(T::Native, T::Native) -> T::Native, -{ - let mut ca = match (lhs.len(), rhs.len()) { - (a, b) if a == b => arity::binary(lhs, rhs, |lhs, rhs| kernel(lhs, rhs)), - // broadcast right path - (_, 1) => { - let opt_rhs = rhs.get(0); - match opt_rhs { - None => ChunkedArray::full_null(lhs.name(), lhs.len()), - Some(rhs) => lhs.apply_values(|lhs| operation(lhs, rhs)), - } - }, - (1, _) => { - let opt_lhs = lhs.get(0); - match opt_lhs { - None => ChunkedArray::full_null(lhs.name(), rhs.len()), - Some(lhs) => rhs.apply_values(|rhs| operation(lhs, rhs)), +use super::*; +use crate::chunked_array::arity::{ + apply_binary_kernel_broadcast, apply_binary_kernel_broadcast_owned, unary_kernel, + unary_kernel_owned, +}; + +macro_rules! impl_op_overload { + ($op: ident, $trait_method: ident, $ca_method: ident, $ca_method_scalar: ident) => { + impl $op for ChunkedArray { + type Output = ChunkedArray; + + fn $trait_method(self, rhs: Self) -> Self::Output { + ArithmeticChunked::$ca_method(self, rhs) } - }, - _ => panic!("Cannot apply operation on arrays of different lengths"), - }; - ca.rename(lhs.name()); - ca -} + } + + impl $op for &ChunkedArray { + type Output = ChunkedArray; -/// This assigns to the owned buffer if the ref count is 1 -fn arithmetic_helper_owned( - mut lhs: ChunkedArray, - mut rhs: ChunkedArray, - kernel: Kernel, - operation: F, -) -> ChunkedArray -where - T: PolarsNumericType, - Kernel: Fn(&mut PrimitiveArray, &mut PrimitiveArray), - F: Fn(T::Native, T::Native) -> T::Native, -{ - let ca = match (lhs.len(), rhs.len()) { - (a, b) if a == b => { - let (mut lhs, mut rhs) = align_chunks_binary_owned(lhs, rhs); - // safety, we do no t change the lengths - unsafe { - lhs.downcast_iter_mut() - .zip(rhs.downcast_iter_mut()) - .for_each(|(lhs, rhs)| kernel(lhs, rhs)); + fn $trait_method(self, rhs: Self) -> Self::Output { + ArithmeticChunked::$ca_method(self, rhs) } - lhs.compute_len(); - lhs.set_sorted_flag(IsSorted::Not); - lhs - }, - // broadcast right path - (_, 1) => { - let opt_rhs = rhs.get(0); - match opt_rhs { - None => ChunkedArray::full_null(lhs.name(), lhs.len()), - Some(rhs) => { - lhs.apply_mut(|lhs| operation(lhs, rhs)); - lhs - }, + } + + // TODO: make this more strict instead of casting. + impl $op for ChunkedArray { + type Output = ChunkedArray; + + fn $trait_method(self, rhs: N) -> Self::Output { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + ArithmeticChunked::$ca_method_scalar(self, rhs) } - }, - (1, _) => { - let opt_lhs = lhs.get(0); - match opt_lhs { - None => ChunkedArray::full_null(lhs.name(), rhs.len()), - Some(lhs_val) => { - rhs.apply_mut(|rhs| operation(lhs_val, rhs)); - rhs.rename(lhs.name()); - rhs - }, + } + + impl $op for &ChunkedArray { + type Output = ChunkedArray; + + fn $trait_method(self, rhs: N) -> Self::Output { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + ArithmeticChunked::$ca_method_scalar(self, rhs) } - }, - _ => panic!("Cannot apply operation on arrays of different lengths"), + } }; - ca } -// Operands on ChunkedArray & ChunkedArray +impl_op_overload!(Add, add, wrapping_add, wrapping_add_scalar); +impl_op_overload!(Sub, sub, wrapping_sub, wrapping_sub_scalar); +impl_op_overload!(Mul, mul, wrapping_mul, wrapping_mul_scalar); +impl_op_overload!(Div, div, legacy_div, legacy_div_scalar); // FIXME: replace this with true division. +impl_op_overload!(Rem, rem, wrapping_mod, wrapping_mod_scalar); + +pub trait ArithmeticChunked { + type Scalar; + type Out; + type TrueDivOut; + + fn wrapping_neg(self) -> Self::Out; + fn wrapping_add(self, rhs: Self) -> Self::Out; + fn wrapping_sub(self, rhs: Self) -> Self::Out; + fn wrapping_mul(self, rhs: Self) -> Self::Out; + fn wrapping_floor_div(self, rhs: Self) -> Self::Out; + fn wrapping_trunc_div(self, rhs: Self) -> Self::Out; + fn wrapping_mod(self, rhs: Self) -> Self::Out; + + fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out; + fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out; + fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out; + fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out; + + fn true_div(self, rhs: Self) -> Self::TrueDivOut; + fn true_div_scalar(self, rhs: Self::Scalar) -> Self::TrueDivOut; + fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::TrueDivOut; + + // TODO: remove these. + // These are flooring division for integer types, true division for floating point types. + fn legacy_div(self, rhs: Self) -> Self::Out; + fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out; +} + +impl ArithmeticChunked for ChunkedArray { + type Scalar = T::Native; + type Out = ChunkedArray; + type TrueDivOut = ChunkedArray<::TrueDivPolarsType>; -impl Add for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; + fn wrapping_neg(self) -> Self::Out { + unary_kernel_owned(self, ArithmeticKernel::wrapping_neg) + } - fn add(self, rhs: Self) -> Self::Output { - arithmetic_helper( + fn wrapping_add(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( self, rhs, - ::add, - |lhs, rhs| lhs + rhs, + ArithmeticKernel::wrapping_add, + |l, r| ArithmeticKernel::wrapping_add_scalar(r, l), + ArithmeticKernel::wrapping_add_scalar, ) } -} - -impl Div for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - fn div(self, rhs: Self) -> Self::Output { - arithmetic_helper( + fn wrapping_sub(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( self, rhs, - ::div, - |lhs, rhs| lhs / rhs, + ArithmeticKernel::wrapping_sub, + ArithmeticKernel::wrapping_sub_scalar_lhs, + ArithmeticKernel::wrapping_sub_scalar, ) } -} -impl Mul for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn mul(self, rhs: Self) -> Self::Output { - arithmetic_helper( + fn wrapping_mul(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( self, rhs, - ::mul, - |lhs, rhs| lhs * rhs, + ArithmeticKernel::wrapping_mul, + |l, r| ArithmeticKernel::wrapping_mul_scalar(r, l), + ArithmeticKernel::wrapping_mul_scalar, ) } -} - -impl Rem for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - fn rem(self, rhs: Self) -> Self::Output { - arithmetic_helper( + fn wrapping_floor_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( self, rhs, - ::rem, - |lhs, rhs| lhs % rhs, + ArithmeticKernel::wrapping_floor_div, + ArithmeticKernel::wrapping_floor_div_scalar_lhs, + ArithmeticKernel::wrapping_floor_div_scalar, ) } -} - -impl Sub for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - fn sub(self, rhs: Self) -> Self::Output { - arithmetic_helper( + fn wrapping_trunc_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( self, rhs, - ::sub, - |lhs, rhs| lhs - rhs, + ArithmeticKernel::wrapping_trunc_div, + ArithmeticKernel::wrapping_trunc_div_scalar_lhs, + ArithmeticKernel::wrapping_trunc_div_scalar, ) } -} -impl Add for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( + fn wrapping_mod(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( self, rhs, - |a, b| arity_assign::binary(a, b, |a, b| a + b), - |lhs, rhs| lhs + rhs, + ArithmeticKernel::wrapping_mod, + ArithmeticKernel::wrapping_mod_scalar_lhs, + ArithmeticKernel::wrapping_mod_scalar, ) } -} -impl Div for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; + fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| ArithmeticKernel::wrapping_add_scalar(a, rhs)) + } + + fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| ArithmeticKernel::wrapping_sub_scalar(a, rhs)) + } + + fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel_owned(rhs, |a| ArithmeticKernel::wrapping_sub_scalar_lhs(lhs, a)) + } + + fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| ArithmeticKernel::wrapping_mul_scalar(a, rhs)) + } + + fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| { + ArithmeticKernel::wrapping_floor_div_scalar(a, rhs) + }) + } + + fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel_owned(rhs, |a| { + ArithmeticKernel::wrapping_floor_div_scalar_lhs(lhs, a) + }) + } - fn div(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( + fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| { + ArithmeticKernel::wrapping_trunc_div_scalar(a, rhs) + }) + } + + fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel_owned(rhs, |a| { + ArithmeticKernel::wrapping_trunc_div_scalar_lhs(lhs, a) + }) + } + + fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| ArithmeticKernel::wrapping_mod_scalar(a, rhs)) + } + + fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel_owned(rhs, |a| ArithmeticKernel::wrapping_mod_scalar_lhs(lhs, a)) + } + + fn true_div(self, rhs: Self) -> Self::TrueDivOut { + apply_binary_kernel_broadcast_owned( self, rhs, - |a, b| arity_assign::binary(a, b, |a, b| a / b), - |lhs, rhs| lhs / rhs, + ArithmeticKernel::true_div, + ArithmeticKernel::true_div_scalar_lhs, + ArithmeticKernel::true_div_scalar, ) } -} -impl Mul for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; + fn true_div_scalar(self, rhs: Self::Scalar) -> Self::TrueDivOut { + unary_kernel_owned(self, |a| ArithmeticKernel::true_div_scalar(a, rhs)) + } + + fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::TrueDivOut { + unary_kernel_owned(rhs, |a| ArithmeticKernel::true_div_scalar_lhs(lhs, a)) + } - fn mul(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( + fn legacy_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( self, rhs, - |a, b| arity_assign::binary(a, b, |a, b| a * b), - |lhs, rhs| lhs * rhs, + ArithmeticKernel::legacy_div, + ArithmeticKernel::legacy_div_scalar_lhs, + ArithmeticKernel::legacy_div_scalar, ) } + + fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| ArithmeticKernel::legacy_div_scalar(a, rhs)) + } + + fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel_owned(rhs, |a| ArithmeticKernel::legacy_div_scalar_lhs(lhs, a)) + } } -impl Sub for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; +impl ArithmeticChunked for &ChunkedArray { + type Scalar = T::Native; + type Out = ChunkedArray; + type TrueDivOut = ChunkedArray<::TrueDivPolarsType>; - fn sub(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( + fn wrapping_neg(self) -> Self::Out { + unary_kernel(self, |a| ArithmeticKernel::wrapping_neg(a.clone())) + } + + fn wrapping_add(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( self, rhs, - |a, b| arity_assign::binary(a, b, |a, b| a - b), - |lhs, rhs| lhs - rhs, + |l, r| ArithmeticKernel::wrapping_add(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_add_scalar(r.clone(), l), + |l, r| ArithmeticKernel::wrapping_add_scalar(l.clone(), r), ) } -} -impl Rem for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn rem(self, rhs: Self) -> Self::Output { - (&self).rem(&rhs) + fn wrapping_sub(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_sub(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_sub_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::wrapping_sub_scalar(l.clone(), r), + ) } -} -// Operands on ChunkedArray & Num + fn wrapping_mul(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_mul(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_mul_scalar(r.clone(), l), + |l, r| ArithmeticKernel::wrapping_mul_scalar(l.clone(), r), + ) + } -impl Add for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; + fn wrapping_floor_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_floor_div(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_floor_div_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::wrapping_floor_div_scalar(l.clone(), r), + ) + } - fn add(self, rhs: N) -> Self::Output { - let adder: T::Native = NumCast::from(rhs).unwrap(); - let mut out = self.apply_values(|val| val + adder); - out.set_sorted_flag(self.is_sorted_flag()); - out + fn wrapping_trunc_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_trunc_div(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_trunc_div_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::wrapping_trunc_div_scalar(l.clone(), r), + ) } -} -impl Sub for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; + fn wrapping_mod(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_mod(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_mod_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::wrapping_mod_scalar(l.clone(), r), + ) + } - fn sub(self, rhs: N) -> Self::Output { - let subber: T::Native = NumCast::from(rhs).unwrap(); - let mut out = self.apply_values(|val| val - subber); - out.set_sorted_flag(self.is_sorted_flag()); - out + fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_add_scalar(a.clone(), rhs) + }) } -} -impl Div for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn div(self, rhs: N) -> Self::Output { - let rhs: T::Native = NumCast::from(rhs).expect("could not cast"); - let mut out = self - .apply_kernel(&|arr| Box::new(::div_scalar(arr, &rhs))); - - if rhs.tot_lt(&T::Native::zero()) { - out.set_sorted_flag(self.is_sorted_flag().reverse()); - } else { - out.set_sorted_flag(self.is_sorted_flag()); - } - out + fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_sub_scalar(a.clone(), rhs) + }) } -} -impl Mul for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; + fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel(rhs, |a| { + ArithmeticKernel::wrapping_sub_scalar_lhs(lhs, a.clone()) + }) + } - fn mul(self, rhs: N) -> Self::Output { - // don't set sorted flag as probability of overflow is higher - let multiplier: T::Native = NumCast::from(rhs).unwrap(); - let rhs = ChunkedArray::from_vec("", vec![multiplier]); - self.mul(&rhs) + fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_mul_scalar(a.clone(), rhs) + }) } -} -impl Rem for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; + fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_floor_div_scalar(a.clone(), rhs) + }) + } - fn rem(self, rhs: N) -> Self::Output { - let rhs: T::Native = NumCast::from(rhs).expect("could not cast"); - let rhs = ChunkedArray::from_vec("", vec![rhs]); - self.rem(&rhs) + fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel(rhs, |a| { + ArithmeticKernel::wrapping_floor_div_scalar_lhs(lhs, a.clone()) + }) } -} -impl Add for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; + fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_trunc_div_scalar(a.clone(), rhs) + }) + } - fn add(self, rhs: N) -> Self::Output { - (&self).add(rhs) + fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel(rhs, |a| { + ArithmeticKernel::wrapping_trunc_div_scalar_lhs(lhs, a.clone()) + }) } -} -impl Sub for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; + fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_mod_scalar(a.clone(), rhs) + }) + } - fn sub(self, rhs: N) -> Self::Output { - (&self).sub(rhs) + fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel(rhs, |a| { + ArithmeticKernel::wrapping_mod_scalar_lhs(lhs, a.clone()) + }) } -} -impl Div for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; + fn true_div(self, rhs: Self) -> Self::TrueDivOut { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::true_div(l.clone(), r.clone()), + |l, r| ArithmeticKernel::true_div_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::true_div_scalar(l.clone(), r), + ) + } - fn div(self, rhs: N) -> Self::Output { - (&self).div(rhs) + fn true_div_scalar(self, rhs: Self::Scalar) -> Self::TrueDivOut { + unary_kernel(self, |a| ArithmeticKernel::true_div_scalar(a.clone(), rhs)) } -} -impl Mul for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; + fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::TrueDivOut { + unary_kernel(rhs, |a| { + ArithmeticKernel::true_div_scalar_lhs(lhs, a.clone()) + }) + } - fn mul(mut self, rhs: N) -> Self::Output { - let multiplier: T::Native = NumCast::from(rhs).unwrap(); - self.apply_mut(|val| val * multiplier); - self + fn legacy_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::legacy_div(l.clone(), r.clone()), + |l, r| ArithmeticKernel::legacy_div_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::legacy_div_scalar(l.clone(), r), + ) } -} -impl Rem for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; + fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::legacy_div_scalar(a.clone(), rhs) + }) + } - fn rem(self, rhs: N) -> Self::Output { - (&self).rem(rhs) + fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel(rhs, |a| { + ArithmeticKernel::legacy_div_scalar_lhs(lhs, a.clone()) + }) } } diff --git a/crates/polars-core/src/chunked_array/array/iterator.rs b/crates/polars-core/src/chunked_array/array/iterator.rs index f31b2bbf86e8..a9ecbf43ffb8 100644 --- a/crates/polars-core/src/chunked_array/array/iterator.rs +++ b/crates/polars-core/src/chunked_array/array/iterator.rs @@ -44,7 +44,7 @@ impl ArrayChunked { _ => inner_dtype.clone(), }; - // Safety: + // SAFETY: // inner type passed as physical type let series_container = unsafe { Box::pin(Series::from_chunks_and_dtype_unchecked( @@ -103,6 +103,7 @@ impl ArrayChunked { } /// Apply a closure `F` to each array. + /// /// # Safety /// Return series of `F` must has the same dtype and number of elements as input. #[must_use] @@ -123,6 +124,32 @@ impl ArrayChunked { .collect_ca_with_dtype(self.name(), self.dtype().clone()) } + /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise. + /// + /// # Safety + // Return series of `F` must has the same dtype and number of elements as input series. + #[must_use] + pub unsafe fn zip_and_apply_amortized_same_type<'a, T, F>( + &'a self, + ca: &'a ChunkedArray, + mut f: F, + ) -> Self + where + T: PolarsDataType, + F: FnMut(Option>, Option>) -> Option, + { + if self.is_empty() { + return self.clone(); + } + self.amortized_iter() + .zip(ca.iter()) + .map(|(opt_s, opt_v)| { + let out = f(opt_s, opt_v); + out.map(|s| to_arr(&s)) + }) + .collect_ca_with_dtype(self.name(), self.dtype().clone()) + } + /// Apply a closure `F` elementwise. #[must_use] pub fn apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> ChunkedArray diff --git a/crates/polars-core/src/chunked_array/bitwise.rs b/crates/polars-core/src/chunked_array/bitwise.rs index 820bc95adb5c..a47cd9c82aa3 100644 --- a/crates/polars-core/src/chunked_array/bitwise.rs +++ b/crates/polars-core/src/chunked_array/bitwise.rs @@ -1,11 +1,11 @@ use std::ops::{BitAnd, BitOr, BitXor, Not}; use arrow::compute; +use arrow::compute::bitwise; use arrow::compute::utils::combine_validities_and; -use arrow::legacy::compute::bitwise; -use super::arithmetic::arithmetic_helper; use super::*; +use crate::chunked_array::arity::apply_binary_kernel_broadcast; impl BitAnd for &ChunkedArray where @@ -15,7 +15,13 @@ where type Output = ChunkedArray; fn bitand(self, rhs: Self) -> Self::Output { - arithmetic_helper(self, rhs, bitwise::bitand, |a, b| a.bitand(b)) + apply_binary_kernel_broadcast( + self, + rhs, + bitwise::and, + |l, r| bitwise::and_scalar(r, &l), + |l, r| bitwise::and_scalar(l, &r), + ) } } @@ -27,7 +33,13 @@ where type Output = ChunkedArray; fn bitor(self, rhs: Self) -> Self::Output { - arithmetic_helper(self, rhs, bitwise::bitor, |a, b| a.bitor(b)) + apply_binary_kernel_broadcast( + self, + rhs, + bitwise::or, + |l, r| bitwise::or_scalar(r, &l), + |l, r| bitwise::or_scalar(l, &r), + ) } } @@ -39,7 +51,13 @@ where type Output = ChunkedArray; fn bitxor(self, rhs: Self) -> Self::Output { - arithmetic_helper(self, rhs, bitwise::bitxor, |a, b| a.bitxor(b)) + apply_binary_kernel_broadcast( + self, + rhs, + bitwise::xor, + |l, r| bitwise::xor_scalar(r, &l), + |l, r| bitwise::xor_scalar(l, &r), + ) } } diff --git a/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs b/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs index d2662121c98d..a419ee930401 100644 --- a/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs +++ b/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs @@ -1,7 +1,3 @@ -use arrow::array::{ - Array, MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, PrimitiveArray, - PushUnchecked, -}; use arrow::types::NativeType; use polars_utils::unwrap::UnwrapUncheckedRelease; use smartstring::alias::String as SmartString; @@ -16,7 +12,8 @@ pub(crate) struct FixedSizeListNumericBuilder { } impl FixedSizeListNumericBuilder { - /// SAFETY + /// # Safety + /// /// The caller must ensure that the physical numerical type match logical type. pub(crate) unsafe fn new( name: &str, diff --git a/crates/polars-core/src/chunked_array/builder/list/binary.rs b/crates/polars-core/src/chunked_array/builder/list/binary.rs index 9c7f9ee6c872..6382d9269f49 100644 --- a/crates/polars-core/src/chunked_array/builder/list/binary.rs +++ b/crates/polars-core/src/chunked_array/builder/list/binary.rs @@ -27,7 +27,7 @@ impl ListStringChunkedBuilder { if iter.size_hint().0 == 0 { self.fast_explode = false; } - // Safety + // SAFETY: // trusted len, trust the type system self.builder.mut_values().extend_trusted_len(iter); self.builder.try_push_valid().unwrap(); @@ -116,7 +116,7 @@ impl ListBinaryChunkedBuilder { if iter.size_hint().0 == 0 { self.fast_explode = false; } - // Safety + // SAFETY: // trusted len, trust the type system self.builder.mut_values().extend_trusted_len(iter); self.builder.try_push_valid().unwrap(); diff --git a/crates/polars-core/src/chunked_array/builder/list/boolean.rs b/crates/polars-core/src/chunked_array/builder/list/boolean.rs index 4d7bc490cb3d..1d83a05ace00 100644 --- a/crates/polars-core/src/chunked_array/builder/list/boolean.rs +++ b/crates/polars-core/src/chunked_array/builder/list/boolean.rs @@ -26,7 +26,7 @@ impl ListBooleanChunkedBuilder { if iter.size_hint().0 == 0 { self.fast_explode = false; } - // Safety + // SAFETY: // trusted len, trust the type system unsafe { values.extend_trusted_len_unchecked(iter) }; self.builder.try_push_valid().unwrap(); diff --git a/crates/polars-core/src/chunked_array/builder/list/categorical.rs b/crates/polars-core/src/chunked_array/builder/list/categorical.rs index 8f9d9599726b..2807991b377d 100644 --- a/crates/polars-core/src/chunked_array/builder/list/categorical.rs +++ b/crates/polars-core/src/chunked_array/builder/list/categorical.rs @@ -142,7 +142,7 @@ impl ListBuilderTrait for ListLocalCategoricalChunkedBuilder { let len = self.idx_lookup.len(); // Custom hashing / equality functions for comparing the &str to the idx - // Safety: index in hashmap are within bounds of categories + // SAFETY: index in hashmap are within bounds of categories let r = unsafe { self.idx_lookup.raw_table_mut().find_or_find_insert_slot( hash_cat, @@ -155,13 +155,13 @@ impl ListBuilderTrait for ListLocalCategoricalChunkedBuilder { match r { Ok(v) => { - // Safety: Bucket is initialized + // SAFETY: Bucket is initialized idx_mapping.insert_unique_unchecked(idx as u32, unsafe { v.as_ref().0 .0 }); }, Err(e) => { idx_mapping.insert_unique_unchecked(idx as u32, len as u32); self.categories.push(Some(cat)); - // Safety: No mutations in hashmap since find_or_find_insert_slot call + // SAFETY: No mutations in hashmap since find_or_find_insert_slot call unsafe { self.idx_lookup.raw_table_mut().insert_in_slot( hash_cat, @@ -174,7 +174,7 @@ impl ListBuilderTrait for ListLocalCategoricalChunkedBuilder { } let op = |opt_v: Option<&u32>| opt_v.map(|v| *idx_mapping.get(v).unwrap()); - // Safety: length is correct as we do one-one mapping over ca. + // SAFETY: length is correct as we do one-one mapping over ca. let iter = unsafe { ca.physical() .downcast_iter() diff --git a/crates/polars-core/src/chunked_array/builder/list/mod.rs b/crates/polars-core/src/chunked_array/builder/list/mod.rs index 0c669490fcd2..db23be277ff0 100644 --- a/crates/polars-core/src/chunked_array/builder/list/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/list/mod.rs @@ -142,12 +142,15 @@ pub fn get_list_builder( Some(inner_type_logical.clone()), ))), #[cfg(feature = "dtype-decimal")] - DataType::Decimal(_, _) => Ok(Box::new(ListPrimitiveChunkedBuilder::::new( - name, - list_capacity, - value_capacity, - inner_type_logical.clone(), - ))), + DataType::Decimal(_, _) => Ok(Box::new( + ListPrimitiveChunkedBuilder::::new_with_values_type( + name, + list_capacity, + value_capacity, + physical_type, + inner_type_logical.clone(), + ), + )), _ => { macro_rules! get_primitive_builder { ($type:ty) => {{ diff --git a/crates/polars-core/src/chunked_array/builder/list/primitive.rs b/crates/polars-core/src/chunked_array/builder/list/primitive.rs index af431e0f7d63..ce9caff3a116 100644 --- a/crates/polars-core/src/chunked_array/builder/list/primitive.rs +++ b/crates/polars-core/src/chunked_array/builder/list/primitive.rs @@ -30,6 +30,26 @@ where } } + pub fn new_with_values_type( + name: &str, + capacity: usize, + values_capacity: usize, + values_type: DataType, + logical_type: DataType, + ) -> Self { + let values = MutablePrimitiveArray::::with_capacity_from( + values_capacity, + values_type.to_arrow(true), + ); + let builder = LargePrimitiveBuilder::::new_with_capacity(values, capacity); + let field = Field::new(name, DataType::List(Box::new(logical_type))); + Self { + builder, + field, + fast_explode: true, + } + } + #[inline] pub fn append_slice(&mut self, items: &[T::Native]) { let values = self.builder.mut_values(); @@ -58,7 +78,7 @@ where if iter.size_hint().0 == 0 { self.fast_explode = false; } - // Safety + // SAFETY: // trusted len, trust the type system unsafe { values.extend_trusted_len_values_unchecked(iter) }; self.builder.try_push_valid().unwrap(); @@ -72,7 +92,7 @@ where if iter.size_hint().0 == 0 { self.fast_explode = false; } - // Safety + // SAFETY: // trusted len, trust the type system unsafe { values.extend_trusted_len_unchecked(iter) }; self.builder.try_push_valid().unwrap(); @@ -102,7 +122,7 @@ where if !arr.has_validity() { values.extend_from_slice(arr.values().as_slice()) } else { - // Safety: + // SAFETY: // Arrow arrays are trusted length iterators. unsafe { values.extend_trusted_len_unchecked(arr.into_iter()) } } diff --git a/crates/polars-core/src/chunked_array/builder/mod.rs b/crates/polars-core/src/chunked_array/builder/mod.rs index e31fa2968b7c..fec9b92d8681 100644 --- a/crates/polars-core/src/chunked_array/builder/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/mod.rs @@ -6,7 +6,6 @@ mod null; mod primitive; mod string; -use std::iter::FromIterator; use std::marker::PhantomData; use std::sync::Arc; diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 7cfaafb1db80..24713048cbc4 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -1,10 +1,7 @@ //! Implementations of the ChunkCast Trait. -use std::convert::TryFrom; use arrow::compute::cast::CastOptions; -#[cfg(feature = "dtype-categorical")] -use crate::chunked_array::categorical::CategoricalChunkedBuilder; #[cfg(feature = "timezones")] use crate::chunked_array::temporal::validate_time_zone; #[cfg(feature = "dtype-datetime")] @@ -91,7 +88,7 @@ where { fn cast_impl(&self, data_type: &DataType, checked: bool) -> PolarsResult { if self.dtype() == data_type { - // safety: chunks are correct dtype + // SAFETY: chunks are correct dtype let mut out = unsafe { Series::from_chunks_and_dtype_unchecked(self.name(), self.chunks.clone(), data_type) }; @@ -105,7 +102,7 @@ where self.dtype() == &DataType::UInt32, ComputeError: "cannot cast numeric types to 'Categorical'" ); - // SAFETY + // SAFETY: // we are guarded by the type system let ca = unsafe { &*(self as *const ChunkedArray as *const UInt32Chunked) }; @@ -139,10 +136,10 @@ where polars_bail!(OutOfBounds: "index {} is bigger than the number of categories {}",m,categories.len()); } } - // SAFETY + // SAFETY: // we are guarded by the type system let ca = unsafe { &*(self as *const ChunkedArray as *const UInt32Chunked) }; - // SAFETY indices are in bound + // SAFETY: indices are in bound unsafe { Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( ca.clone(), @@ -195,7 +192,7 @@ where DataType::Categorical(Some(rev_map), ordering) | DataType::Enum(Some(rev_map), ordering) => { if self.dtype() == &DataType::UInt32 { - // safety: + // SAFETY: // we are guarded by the type system. let ca = unsafe { &*(self as *const ChunkedArray as *const UInt32Chunked) }; Ok(unsafe { @@ -222,7 +219,7 @@ impl ChunkCast for StringChunked { #[cfg(feature = "dtype-categorical")] DataType::Categorical(rev_map, ordering) => match rev_map { None => { - // Safety: length is correct + // SAFETY: length is correct let iter = unsafe { self.downcast_iter().flatten().trust_my_length(self.len()) }; let builder = @@ -360,24 +357,9 @@ impl ChunkCast for BinaryOffsetChunked { } } -fn boolean_to_string(ca: &BooleanChunked) -> StringChunked { - ca.into_iter() - .map(|opt_b| match opt_b { - Some(true) => Some("true"), - Some(false) => Some("false"), - None => None, - }) - .collect() -} - impl ChunkCast for BooleanChunked { fn cast(&self, data_type: &DataType) -> PolarsResult { match data_type { - DataType::String => { - let mut ca = boolean_to_string(self); - ca.rename(self.name()); - Ok(ca.into_series()) - }, #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields), _ => cast_impl(self.name(), &self.chunks, data_type), @@ -406,7 +388,7 @@ impl ChunkCast for ListChunked { _ => { // ensure the inner logical type bubbles up let (arr, child_type) = cast_list(self, child_type)?; - // Safety: we just casted so the dtype matches. + // SAFETY: we just casted so the dtype matches. // we must take this path to correct for physical types. unsafe { Ok(Series::from_chunks_and_dtype_unchecked( @@ -423,7 +405,7 @@ impl ChunkCast for ListChunked { let physical_type = data_type.to_physical(); // cast to the physical type to avoid logical chunks. let chunks = cast_chunks(self.chunks(), &physical_type, true)?; - // Safety: we just casted so the dtype matches. + // SAFETY: we just casted so the dtype matches. // we must take this path to correct for physical types. unsafe { Ok(Series::from_chunks_and_dtype_unchecked( @@ -468,7 +450,7 @@ impl ChunkCast for ArrayChunked { _ => { // ensure the inner logical type bubbles up let (arr, child_type) = cast_fixed_size_list(self, child_type)?; - // Safety: we just casted so the dtype matches. + // SAFETY: we just casted so the dtype matches. // we must take this path to correct for physical types. unsafe { Ok(Series::from_chunks_and_dtype_unchecked( @@ -484,7 +466,7 @@ impl ChunkCast for ArrayChunked { let physical_type = data_type.to_physical(); // cast to the physical type to avoid logical chunks. let chunks = cast_chunks(self.chunks(), &physical_type, true)?; - // Safety: we just casted so the dtype matches. + // SAFETY: we just casted so the dtype matches. // we must take this path to correct for physical types. unsafe { Ok(Series::from_chunks_and_dtype_unchecked( @@ -510,7 +492,7 @@ fn cast_list(ca: &ListChunked, child_type: &DataType) -> PolarsResult<(ArrayRef, // TODO!: consider a version that works on chunks and merges the data-types and arrays. let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); - // safety: inner dtype is passed correctly + // SAFETY: inner dtype is passed correctly let s = unsafe { Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], &ca.inner_dtype()) }; @@ -535,7 +517,7 @@ unsafe fn cast_list_unchecked(ca: &ListChunked, child_type: &DataType) -> Polars // TODO! add chunked, but this must correct for list offsets. let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); - // safety: inner dtype is passed correctly + // SAFETY: inner dtype is passed correctly let s = unsafe { Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], &ca.inner_dtype()) }; @@ -566,7 +548,7 @@ fn cast_fixed_size_list( ) -> PolarsResult<(ArrayRef, DataType)> { let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); - // safety: inner dtype is passed correctly + // SAFETY: inner dtype is passed correctly let s = unsafe { Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], &ca.inner_dtype()) }; diff --git a/crates/polars-core/src/chunked_array/comparison/categorical.rs b/crates/polars-core/src/chunked_array/comparison/categorical.rs index 011a9c7a607a..6036d1772b4b 100644 --- a/crates/polars-core/src/chunked_array/comparison/categorical.rs +++ b/crates/polars-core/src/chunked_array/comparison/categorical.rs @@ -49,7 +49,7 @@ where } else { match (lhs.len(), rhs.len()) { (lhs_len, 1) => { - // Safety: physical is in range of revmap + // SAFETY: physical is in range of revmap let v = unsafe { rhs.physical() .get(0) @@ -65,7 +65,7 @@ where .collect_ca_trusted(lhs.name())) }, (1, rhs_len) => { - // Safety: physical is in range of revmap + // SAFETY: physical is in range of revmap let v = unsafe { lhs.physical() .get(0) @@ -367,7 +367,7 @@ where Ok( BooleanChunked::from_iter_trusted_length(lhs.physical().into_iter().map(|opt_idx| { - // Safety: indexing into bitmap with same length as original array + // SAFETY: indexing into bitmap with same length as original array opt_idx.map(|idx| unsafe { bitmap.get_bit_unchecked(idx as usize) }) })) .with_name(lhs.name()), diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 5d2a335921e4..57fdfc8c0591 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -8,11 +8,11 @@ use std::ops::Not; use arrow::array::BooleanArray; use arrow::bitmap::MutableBitmap; use arrow::compute; -use arrow::legacy::prelude::FromData; use num_traits::{NumCast, ToPrimitive}; use polars_compute::comparisons::TotalOrdKernel; use crate::prelude::*; +use crate::series::implementations::null::NullChunked; use crate::series::IsSorted; impl ChunkCompare<&ChunkedArray> for ChunkedArray @@ -167,6 +167,52 @@ where } } +impl ChunkCompare<&NullChunked> for NullChunked { + type Item = BooleanChunked; + + fn equal(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + } + + fn equal_missing(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full(self.name(), true, get_broadcast_length(self, rhs)) + } + + fn not_equal(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + } + + fn not_equal_missing(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full(self.name(), false, get_broadcast_length(self, rhs)) + } + + fn gt(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + } + + fn gt_eq(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + } + + fn lt(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + } + + fn lt_eq(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + } +} + +#[inline] +fn get_broadcast_length(lhs: &NullChunked, rhs: &NullChunked) -> usize { + match (lhs.len(), rhs.len()) { + (1, len_r) => len_r, + (len_l, 1) => len_l, + (len_l, len_r) if len_l == len_r => len_l, + _ => panic!("Cannot compare two series of different lengths."), + } +} + impl ChunkCompare<&BooleanChunked> for BooleanChunked { type Item = BooleanChunked; @@ -864,10 +910,17 @@ impl ChunkEqualElement for ArrayChunked {} mod test { use std::iter::repeat; - use super::super::arithmetic::test::create_two_chunked; use super::super::test::get_chunked_array; use crate::prelude::*; + pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) { + let mut a1 = Int32Chunked::new("a", &[1, 2, 3]); + let a2 = Int32Chunked::new("a", &[4, 5, 6]); + let a3 = Int32Chunked::new("a", &[1, 2, 3, 4, 5, 6]); + a1.append(&a2); + (a1, a3) + } + #[test] fn test_bitwise_ops() { let a = BooleanChunked::new("a", &[true, false, false]); diff --git a/crates/polars-core/src/chunked_array/drop.rs b/crates/polars-core/src/chunked_array/drop.rs index 1e2ccdd8650c..f92503dda242 100644 --- a/crates/polars-core/src/chunked_array/drop.rs +++ b/crates/polars-core/src/chunked_array/drop.rs @@ -4,7 +4,7 @@ use crate::prelude::*; impl Drop for ChunkedArray { fn drop(&mut self) { if matches!(self.dtype(), DataType::List(_)) { - // Safety + // SAFETY: // guarded by the type system // the transmute only convinces the type system that we are a list // (which we are) diff --git a/crates/polars-core/src/chunked_array/iterator/mod.rs b/crates/polars-core/src/chunked_array/iterator/mod.rs index d9cc32d40b0b..27781b892876 100644 --- a/crates/polars-core/src/chunked_array/iterator/mod.rs +++ b/crates/polars-core/src/chunked_array/iterator/mod.rs @@ -3,10 +3,24 @@ use arrow::array::*; use crate::prelude::*; #[cfg(feature = "dtype-struct")] use crate::series::iterator::SeriesIter; -use crate::utils::CustomIterTools; pub mod par; +impl ChunkedArray +where + T: PolarsDataType, +{ + #[inline] + pub fn iter(&self) -> impl PolarsIterator>> { + // SAFETY: we set the correct length of the iterator. + unsafe { + self.downcast_iter() + .flat_map(|arr| arr.iter()) + .trust_my_length(self.len()) + } + } +} + /// A [`PolarsIterator`] is an iterator over a [`ChunkedArray`] which contains polars types. A [`PolarsIterator`] /// must implement [`ExactSizeIterator`] and [`DoubleEndedIterator`]. pub trait PolarsIterator: @@ -438,7 +452,7 @@ impl<'a> Iterator for StructIter<'a> { for it in &mut self.field_iter { self.buf.push(it.next()?); } - // Safety: + // SAFETY: // Lifetime is bound to struct, we just cannot set the lifetime for the iterator trait unsafe { Some(std::mem::transmute::<&'_ [AnyValue], &'a [AnyValue]>( diff --git a/crates/polars-core/src/chunked_array/iterator/par/list.rs b/crates/polars-core/src/chunked_array/iterator/par/list.rs index 4daf74be3376..59985de275fd 100644 --- a/crates/polars-core/src/chunked_array/iterator/par/list.rs +++ b/crates/polars-core/src/chunked_array/iterator/par/list.rs @@ -16,7 +16,7 @@ impl ListChunked { pub fn par_iter(&self) -> impl ParallelIterator> + '_ { self.chunks.par_iter().flat_map(move |arr| { let dtype = self.inner_dtype(); - // Safety: + // SAFETY: // guarded by the type system let arr = &**arr; let arr = unsafe { &*(arr as *const dyn Array as *const ListArray) }; diff --git a/crates/polars-core/src/chunked_array/iterator/par/string.rs b/crates/polars-core/src/chunked_array/iterator/par/string.rs index 8480b0d32339..6130fb711b4e 100644 --- a/crates/polars-core/src/chunked_array/iterator/par/string.rs +++ b/crates/polars-core/src/chunked_array/iterator/par/string.rs @@ -16,7 +16,7 @@ impl StringChunked { assert_eq!(self.chunks.len(), 1); let arr = &*self.chunks[0]; - // Safety: + // SAFETY: // guarded by the type system let arr = unsafe { &*(arr as *const dyn Array as *const Utf8ViewArray) }; (0..arr.len()) @@ -26,7 +26,7 @@ impl StringChunked { pub fn par_iter(&self) -> impl ParallelIterator> + '_ { self.chunks.par_iter().flat_map(move |arr| { - // Safety: + // SAFETY: // guarded by the type system let arr = &**arr; let arr = unsafe { &*(arr as *const dyn Array as *const Utf8ViewArray) }; diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index 82817c12deab..890e156fc3f0 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -4,7 +4,6 @@ use std::ptr::NonNull; use crate::prelude::*; use crate::series::unstable::{ArrayBox, UnstableSeries}; -use crate::utils::CustomIterTools; pub struct AmortizedListIter<'a, I: Iterator>> { len: usize, @@ -46,7 +45,7 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a // structs arrays are bound to the series not to the arrayref // so we must get a hold to the new array if matches!(self.inner_dtype, DataType::Struct(_)) { - // Safety + // SAFETY: // dtype is known unsafe { let mut s = Series::from_chunks_and_dtype_unchecked( @@ -74,7 +73,7 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a // make sure that the length is correct self.series_container._get_inner_mut().compute_len(); - // Safety + // SAFETY: // we cannot control the lifetime of an iterators `next` method. // but as long as self is alive the reference to the series container is valid let refer = &mut *self.series_container; @@ -144,7 +143,7 @@ impl ListChunked { _ => inner_dtype.clone(), }; - // Safety: + // SAFETY: // inner type passed as physical type let series_container = unsafe { let mut s = Series::from_chunks_and_dtype_unchecked( @@ -180,6 +179,17 @@ impl ListChunked { unsafe { self.amortized_iter().map(f).collect_ca(self.name()) } } + pub fn try_apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> PolarsResult> + where + V: PolarsDataType, + F: FnMut(Option>) -> PolarsResult> + Copy, + V::Array: ArrayFromIter>, + { + // TODO! make an amortized iter that does not flatten + // SAFETY: unstable series never lives longer than the iterator. + unsafe { self.amortized_iter().map(f).try_collect_ca(self.name()) } + } + pub fn for_each_amortized<'a, F>(&'a self, f: F) where F: FnMut(Option>), @@ -228,6 +238,54 @@ impl ListChunked { out } + #[must_use] + pub fn binary_zip_and_apply_amortized<'a, T, U, F>( + &'a self, + ca1: &'a ChunkedArray, + ca2: &'a ChunkedArray, + mut f: F, + ) -> Self + where + T: PolarsDataType, + U: PolarsDataType, + F: FnMut( + Option>, + Option>, + Option>, + ) -> Option, + { + if self.is_empty() { + return self.clone(); + } + let mut fast_explode = self.null_count() == 0; + // SAFETY: unstable series never lives longer than the iterator. + let mut out: ListChunked = unsafe { + self.amortized_iter() + .zip(ca1.iter()) + .zip(ca2.iter()) + .map(|((opt_s, opt_u), opt_v)| { + let out = f(opt_s, opt_u, opt_v); + match out { + Some(out) => { + fast_explode &= !out.is_empty(); + Some(out) + }, + None => { + fast_explode = false; + out + }, + } + }) + .collect_trusted() + }; + + out.rename(self.name()); + if fast_explode { + out.set_fast_explode(); + } + out + } + pub fn try_zip_and_apply_amortized<'a, T, I, F>( &'a self, ca: &'a ChunkedArray, diff --git a/crates/polars-core/src/chunked_array/list/mod.rs b/crates/polars-core/src/chunked_array/list/mod.rs index 5aa0e6ad8618..eba65e8980ab 100644 --- a/crates/polars-core/src/chunked_array/list/mod.rs +++ b/crates/polars-core/src/chunked_array/list/mod.rs @@ -30,6 +30,7 @@ impl ListChunked { } /// Set the logical type of the [`ListChunked`]. + /// /// # Safety /// The caller must ensure that the logical type given fits the physical type of the array. pub unsafe fn to_logical(&mut self, inner_dtype: DataType) { @@ -42,7 +43,7 @@ impl ListChunked { pub fn get_inner(&self) -> Series { let ca = self.rechunk(); let arr = ca.downcast_iter().next().unwrap(); - // SAFETY + // SAFETY: // Inner dtype is passed correctly unsafe { Series::from_chunks_and_dtype_unchecked( @@ -62,7 +63,7 @@ impl ListChunked { let ca = self.rechunk(); let arr = ca.downcast_iter().next().unwrap(); - // SAFETY + // SAFETY: // Inner dtype is passed correctly let elements = unsafe { Series::from_chunks_and_dtype_unchecked( @@ -89,7 +90,7 @@ impl ListChunked { arr.validity().cloned(), ); - // safety: arr's inner dtype is derived from out dtype. + // SAFETY: arr's inner dtype is derived from out dtype. Ok(unsafe { ListChunked::from_chunks_and_dtype_unchecked( ca.name(), diff --git a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs index dc5fd4e48a27..8659a9ff02b7 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs @@ -3,7 +3,6 @@ use arrow::legacy::trusted_len::TrustedLenPush; use hashbrown::hash_map::Entry; use polars_utils::iter::EnumerateIdxTrait; -use crate::datatypes::PlHashMap; use crate::hashing::_HASHMAP_INIT_SIZE; use crate::prelude::*; use crate::{using_string_cache, StringCache, POOL}; @@ -39,7 +38,7 @@ impl CategoricalChunkedBuilder { let len = self.local_mapping.len() as u32; // Custom hashing / equality functions for comparing the &str to the idx - // Safety: index in hashmap are within bounds of categories + // SAFETY: index in hashmap are within bounds of categories let r = unsafe { self.local_mapping.raw_table_mut().find_or_find_insert_slot( h, @@ -53,12 +52,12 @@ impl CategoricalChunkedBuilder { let idx = match r { Ok(v) => { - // Safety: Bucket is initialized + // SAFETY: Bucket is initialized unsafe { v.as_ref().0 .0 } }, Err(e) => { self.categories.push(Some(s)); - // Safety: No mutations in hashmap since find_or_find_insert_slot call + // SAFETY: No mutations in hashmap since find_or_find_insert_slot call unsafe { self.local_mapping .raw_table_mut() @@ -132,7 +131,7 @@ impl CategoricalChunkedBuilder { let mut local_to_global: Vec = Vec::with_capacity(categories.len()); let (id, local_to_global) = crate::STRING_CACHE.apply(|cache| { for (s, h) in categories.values_iter().zip(hashes) { - // Safety: we allocated enough + // SAFETY: we allocated enough unsafe { local_to_global.push_unchecked(cache.insert_from_hash(h, s)) } } local_to_global @@ -160,7 +159,7 @@ impl CategoricalChunkedBuilder { let indices = std::mem::take(&mut self.cat_builder).into(); let indices = UInt32Chunked::with_chunk(&self.name, indices); - // Safety: indices are in bounds of new rev_map + // SAFETY: indices are in bounds of new rev_map unsafe { CategoricalChunked::from_cats_and_rev_map_unchecked( indices, @@ -185,7 +184,7 @@ impl CategoricalChunkedBuilder { } pub fn finish(self) -> CategoricalChunked { - // Safety: keys and values are in bounds + // SAFETY: keys and values are in bounds unsafe { CategoricalChunked::from_keys_and_values( &self.name, @@ -275,7 +274,7 @@ impl CategoricalChunked { // locally we don't need a hashmap because we all categories are 1 integer apart // so the index is local, and the values is global for s in values.values_iter() { - // Safety: we allocated enough + // SAFETY: we allocated enough unsafe { local_to_global.push_unchecked(cache.insert(s)) } } local_to_global @@ -386,7 +385,6 @@ impl CategoricalChunked { #[cfg(test)] mod test { - use crate::chunked_array::categorical::CategoricalChunkedBuilder; use crate::prelude::*; use crate::{disable_string_cache, enable_string_cache, SINGLE_LOCK}; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/from.rs b/crates/polars-core/src/chunked_array/logical/categorical/from.rs index 83a2e96688d7..4be936d79555 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/from.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/from.rs @@ -1,4 +1,3 @@ -use arrow::array::DictionaryArray; use arrow::compute::cast::{cast, utf8view_to_utf8, CastOptions}; use arrow::datatypes::IntegerType; @@ -35,7 +34,7 @@ impl CategoricalChunked { RevMapping::Local(arr, _) => { let values = convert_values(arr, pl_flavor); - // Safety: + // SAFETY: // the keys are in bounds unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() } }, @@ -47,7 +46,7 @@ impl CategoricalChunked { let values = convert_values(values, pl_flavor); - // Safety: + // SAFETY: // the keys are in bounds unsafe { DictionaryArray::try_new_unchecked(dtype, keys, values).unwrap() } }, @@ -68,7 +67,7 @@ impl CategoricalChunked { RevMapping::Local(arr, _) => { let values = convert_values(arr, pl_flavor); - // Safety: + // SAFETY: // the keys are in bounds unsafe { DictionaryArray::try_new_unchecked( @@ -92,7 +91,7 @@ impl CategoricalChunked { let values = convert_values(values, pl_flavor); - // Safety: + // SAFETY: // the keys are in bounds unsafe { DictionaryArray::try_new_unchecked(dtype, keys, values).unwrap() } }, diff --git a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs index 8e7114a67076..375f8cc3e72f 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs @@ -1,6 +1,8 @@ -use std::sync::Arc; +use std::borrow::Cow; use super::*; +use crate::series::IsSorted; +use crate::utils::align_chunks_binary; fn slots_to_mut(slots: &Utf8ViewArray) -> MutablePlString { slots.clone().make_mut() @@ -62,7 +64,7 @@ impl GlobalRevMapMerger { for (cat, idx) in map.iter() { state.map.entry(*cat).or_insert_with(|| { - // Safety + // SAFETY: // within bounds let str_val = unsafe { slots.value_unchecked(*idx as usize) }; let new_idx = state.slots.len() as u32; @@ -174,7 +176,7 @@ pub fn call_categorical_merge_operation( }, _ => polars_bail!(string_cache_mismatch), }; - // Safety: physical and rev map are correctly constructed above + // SAFETY: physical and rev map are correctly constructed above unsafe { Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( new_physical, @@ -201,7 +203,7 @@ pub fn make_categoricals_compatible( // Alter rev map of left let mut new_ca_left = ca_left.clone(); - // Safety: We just made both rev maps compatible only appended categories + // SAFETY: We just made both rev maps compatible only appended categories unsafe { new_ca_left.set_rev_map( new_ca_right.get_rev_map().clone(), @@ -211,3 +213,42 @@ pub fn make_categoricals_compatible( Ok((new_ca_left, new_ca_right)) } + +pub fn make_list_categoricals_compatible( + mut list_ca_left: ListChunked, + list_ca_right: ListChunked, +) -> PolarsResult<(ListChunked, ListChunked)> { + // Make categoricals compatible + + let cat_left = list_ca_left.get_inner(); + let cat_right = list_ca_right.get_inner(); + let (cat_left, cat_right) = + make_categoricals_compatible(cat_left.categorical()?, cat_right.categorical()?)?; + + // we only appended categories to the rev_map at the end, so only change the inner dtype + list_ca_left.set_inner_dtype(cat_left.dtype().clone()); + + // We changed the physicals and the rev_map, offsets and validity buffers are still good + let (list_ca_right, cat_physical): (Cow, Cow) = + align_chunks_binary(&list_ca_right, cat_right.physical()); + let mut list_ca_right = list_ca_right.into_owned(); + // SAFETY: + // Chunks are aligned, length / dtype remains correct + unsafe { + list_ca_right + .downcast_iter_mut() + .zip(cat_physical.chunks()) + .for_each(|(arr, new_phys)| { + *arr = ListArray::new( + arr.data_type().clone(), + arr.offsets().clone(), + new_phys.clone(), + arr.validity().cloned(), + ) + }); + } + // reset the sorted flag and add extra categories back in + list_ca_right.set_sorted_flag(IsSorted::Not); + list_ca_right.set_inner_dtype(cat_right.dtype().clone()); + Ok((list_ca_left, list_ca_right)) +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index ba282994de65..15fa1e994453 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -105,7 +105,7 @@ impl CategoricalChunked { self.get_ordering(), ) }; - out.set_fast_unique(self.can_fast_unique()); + out.set_fast_unique(self._can_fast_unique()); out } @@ -118,7 +118,7 @@ impl CategoricalChunked { RevMapping::Local(categories, _) => categories, }; - // Safety: keys and values are in bounds + // SAFETY: keys and values are in bounds unsafe { Ok(CategoricalChunked::from_keys_and_values_global( self.name(), @@ -177,7 +177,7 @@ impl CategoricalChunked { .collect::>()?; Ok( - // Safety: we created the physical from the enum categories + // SAFETY: we created the physical from the enum categories unsafe { CategoricalChunked::from_cats_and_rev_map_unchecked( new_phys, @@ -274,7 +274,9 @@ impl CategoricalChunked { } } - pub(crate) fn can_fast_unique(&self) -> bool { + /// True if all categories are represented in this array. When this is the case, the unique + /// values of the array are the categories. + pub fn _can_fast_unique(&self) -> bool { self.bit_settings.contains(BitSettings::ORIGINAL) && self.physical.chunks.len() == 1 && self.null_count() == 0 @@ -373,7 +375,8 @@ impl LogicalType for CategoricalChunked { Ok(self .to_enum(categories, *hash)? .set_ordering(*ordering, true) - .into_series()) + .into_series() + .with_name(self.name())) }, DataType::Enum(None, _) => { polars_bail!(ComputeError: "can not cast to enum without categories present") @@ -395,9 +398,11 @@ impl LogicalType for CategoricalChunked { Ok(self.clone().set_ordering(*ordering, true).into_series()) }, dt if dt.is_numeric() => { - // Apply the cast to to the categories and then index into the casted series - let categories = - StringChunked::with_chunk("", self.get_rev_map().get_categories().clone()); + // Apply the cast to the categories and then index into the casted series + let categories = StringChunked::with_chunk( + self.physical.name(), + self.get_rev_map().get_categories().clone(), + ); let casted_series = categories.cast(dtype)?; #[cfg(feature = "bigidx")] @@ -407,7 +412,7 @@ impl LogicalType for CategoricalChunked { } #[cfg(not(feature = "bigidx"))] { - // Safety: Invariant of categorical means indices are in bound + // SAFETY: Invariant of categorical means indices are in bound Ok(unsafe { casted_series.take_unchecked(&self.physical) }) } }, @@ -429,7 +434,7 @@ impl<'a> Iterator for CatIter<'a> { fn next(&mut self) -> Option { self.iter.next().map(|item| { item.map(|idx| { - // Safety: + // SAFETY: // all categories are in bound unsafe { self.rev.get_unchecked(idx) } }) @@ -445,8 +450,6 @@ impl<'a> ExactSizeIterator for CatIter<'a> {} #[cfg(test)] mod test { - use std::convert::TryFrom; - use super::*; use crate::{disable_string_cache, enable_string_cache, SINGLE_LOCK}; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs index b73f4c8d9a38..a5477fccb376 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs @@ -1,10 +1,9 @@ use super::*; -use crate::frame::group_by::IntoGroupsProxy; impl CategoricalChunked { pub fn unique(&self) -> PolarsResult { let cat_map = self.get_rev_map(); - if self.can_fast_unique() { + if self._can_fast_unique() { let ca = match &**cat_map { RevMapping::Local(a, _) => { UInt32Chunked::from_iter_values(self.physical().name(), 0..(a.len() as u32)) @@ -13,7 +12,7 @@ impl CategoricalChunked { UInt32Chunked::from_iter_values(self.physical().name(), map.keys().copied()) }, }; - // safety: + // SAFETY: // we only removed some indexes so we are still in bounds unsafe { let mut out = CategoricalChunked::from_cats_and_rev_map_unchecked( @@ -27,7 +26,7 @@ impl CategoricalChunked { } } else { let ca = self.physical().unique()?; - // safety: + // SAFETY: // we only removed some indexes so we are still in bounds unsafe { Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( @@ -41,7 +40,7 @@ impl CategoricalChunked { } pub fn n_unique(&self) -> PolarsResult { - if self.can_fast_unique() { + if self._can_fast_unique() { Ok(self.get_rev_map().len()) } else { self.physical().n_unique() @@ -66,7 +65,7 @@ impl CategoricalChunked { let mut counts = groups.group_count(); counts.rename("counts"); let cols = vec![values.into_series(), counts.into_series()]; - let df = DataFrame::new_no_checks(cols); + let df = unsafe { DataFrame::new_no_checks(cols) }; df.sort(["counts"], true, false) } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs index ae9e52543494..7f8a97e9406c 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs @@ -158,14 +158,14 @@ impl RevMapping { } rev_map .iter() - // Safety: + // SAFETY: // value is always within bounds .find(|(_k, &v)| (unsafe { a.value_unchecked(v as usize) } == value)) .map(|(k, _v)| *k) }, Self::Local(a, _) => { - // Safety: within bounds + // SAFETY: within bounds unsafe { (0..a.len()).find(|idx| a.value_unchecked(*idx) == value) } .map(|idx| idx as u32) }, diff --git a/crates/polars-core/src/chunked_array/logical/date.rs b/crates/polars-core/src/chunked_array/logical/date.rs index 38b5593a92c2..4585d89b4af9 100644 --- a/crates/polars-core/src/chunked_array/logical/date.rs +++ b/crates/polars-core/src/chunked_array/logical/date.rs @@ -29,9 +29,9 @@ impl LogicalType for DateChunked { fn cast(&self, dtype: &DataType) -> PolarsResult { use DataType::*; - match (self.dtype(), dtype) { + match dtype { #[cfg(feature = "dtype-datetime")] - (Date, Datetime(tu, tz)) => { + Datetime(tu, tz) => { let casted = self.0.cast(dtype)?; let casted = casted.datetime().unwrap(); let conversion = match tu { @@ -44,9 +44,9 @@ impl LogicalType for DateChunked { .into_series()) }, #[cfg(feature = "dtype-time")] - (Date, Time) => Ok(Int64Chunked::full(self.name(), 0i64, self.len()) - .into_time() - .into_series()), + Time => { + polars_bail!(ComputeError: "cannot cast `Date` to `Time`"); + }, _ => self.0.cast(dtype), } } diff --git a/crates/polars-core/src/chunked_array/logical/datetime.rs b/crates/polars-core/src/chunked_array/logical/datetime.rs index be53ccd19c93..337d18357f58 100644 --- a/crates/polars-core/src/chunked_array/logical/datetime.rs +++ b/crates/polars-core/src/chunked_array/logical/datetime.rs @@ -31,67 +31,77 @@ impl LogicalType for DatetimeChunked { fn cast(&self, dtype: &DataType) -> PolarsResult { use DataType::*; match (self.dtype(), dtype) { - (Datetime(TimeUnit::Milliseconds, _), Datetime(TimeUnit::Nanoseconds, tz)) => { - Ok((self.0.as_ref() * 1_000_000i64) - .into_datetime(TimeUnit::Nanoseconds, tz.clone()) - .into_series()) - }, - (Datetime(TimeUnit::Milliseconds, _), Datetime(TimeUnit::Microseconds, tz)) => { - Ok((self.0.as_ref() * 1_000i64) - .into_datetime(TimeUnit::Microseconds, tz.clone()) - .into_series()) - }, - (Datetime(TimeUnit::Nanoseconds, _), Datetime(TimeUnit::Milliseconds, tz)) => { - Ok((self.0.as_ref() / 1_000_000i64) - .into_datetime(TimeUnit::Milliseconds, tz.clone()) - .into_series()) - }, - (Datetime(TimeUnit::Nanoseconds, _), Datetime(TimeUnit::Microseconds, tz)) => { - Ok((self.0.as_ref() / 1_000i64) - .into_datetime(TimeUnit::Microseconds, tz.clone()) - .into_series()) - }, - (Datetime(TimeUnit::Microseconds, _), Datetime(TimeUnit::Milliseconds, tz)) => { - Ok((self.0.as_ref() / 1_000i64) - .into_datetime(TimeUnit::Milliseconds, tz.clone()) - .into_series()) - }, - (Datetime(TimeUnit::Microseconds, _), Datetime(TimeUnit::Nanoseconds, tz)) => { - Ok((self.0.as_ref() * 1_000i64) - .into_datetime(TimeUnit::Nanoseconds, tz.clone()) - .into_series()) + (Datetime(from_unit, _), Datetime(to_unit, tz)) => { + let (multiplier, divisor) = match (from_unit, to_unit) { + // scaling from lower precision to higher precision + (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => (Some(1_000_000i64), None), + (TimeUnit::Milliseconds, TimeUnit::Microseconds) => (Some(1_000i64), None), + (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => (Some(1_000i64), None), + // scaling from higher precision to lower precision + (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => (None, Some(1_000_000i64)), + (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => (None, Some(1_000i64)), + (TimeUnit::Microseconds, TimeUnit::Milliseconds) => (None, Some(1_000i64)), + _ => return self.0.cast(dtype), + }; + let result = match multiplier { + // scale to higher precision (eg: ms → us, ms → ns, us → ns) + Some(m) => Ok((self.0.as_ref() * m) + .into_datetime(*to_unit, tz.clone()) + .into_series()), + // scale to lower precision (eg: ns → us, ns → ms, us → ms) + None => match divisor { + Some(d) => Ok(self + .0 + .apply_values(|v| v.div_euclid(d)) + .into_datetime(*to_unit, tz.clone()) + .into_series()), + None => unreachable!("must always have a time unit divisor here"), + }, + }; + result }, #[cfg(feature = "dtype-date")] - (Datetime(tu, _), Date) => match tu { - TimeUnit::Nanoseconds => Ok((self.0.as_ref() / NS_IN_DAY) - .cast(&Int32) - .unwrap() - .into_date() - .into_series()), - TimeUnit::Microseconds => Ok((self.0.as_ref() / US_IN_DAY) - .cast(&Int32) - .unwrap() - .into_date() - .into_series()), - TimeUnit::Milliseconds => Ok((self.0.as_ref() / MS_IN_DAY) - .cast(&Int32) - .unwrap() - .into_date() - .into_series()), + (Datetime(tu, _), Date) => { + let cast_to_date = |tu_in_day: i64| { + let mut dt = self + .0 + .apply_values(|v| v.div_euclid(tu_in_day)) + .cast(&Int32) + .unwrap() + .into_date() + .into_series(); + dt.set_sorted_flag(self.is_sorted_flag()); + Ok(dt) + }; + match tu { + TimeUnit::Nanoseconds => cast_to_date(NS_IN_DAY), + TimeUnit::Microseconds => cast_to_date(US_IN_DAY), + TimeUnit::Milliseconds => cast_to_date(MS_IN_DAY), + } }, #[cfg(feature = "dtype-time")] - (Datetime(tu, _), Time) => Ok({ - let (modder, multiplier) = match tu { + (Datetime(tu, _), Time) => { + let (scaled_mod, multiplier) = match tu { TimeUnit::Nanoseconds => (NS_IN_DAY, 1i64), TimeUnit::Microseconds => (US_IN_DAY, 1_000i64), TimeUnit::Milliseconds => (MS_IN_DAY, 1_000_000i64), }; - self.0 - .apply_values(|v| (v % modder * multiplier) + (NS_IN_DAY * (v < 0) as i64)) + return Ok(self + .0 + .apply_values(|v| { + let t = v % scaled_mod * multiplier; + t + (NS_IN_DAY * (t < 0) as i64) + }) .into_time() - .into_series() - }), - _ => self.0.cast(dtype), + .into_series()); + }, + _ => return self.0.cast(dtype), } + .map(|mut s| { + // TODO!; implement the divisions/multipliers above + // in a checked manner so that we raise on overflow + s.set_sorted_flag(self.is_sorted_flag()); + s + }) } } diff --git a/crates/polars-core/src/chunked_array/logical/decimal.rs b/crates/polars-core/src/chunked_array/logical/decimal.rs index 649f0e6b66f2..d21d3909b415 100644 --- a/crates/polars-core/src/chunked_array/logical/decimal.rs +++ b/crates/polars-core/src/chunked_array/logical/decimal.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use super::*; use crate::chunked_array::cast::cast_chunks; use crate::prelude::*; @@ -38,16 +40,14 @@ impl Int128Chunked { } pub fn into_decimal( - mut self, + self, precision: Option, scale: usize, ) -> PolarsResult { - self.update_chunks_dtype(precision, scale); // TODO: if precision is None, do we check that the value fits within precision of 38?... if let Some(precision) = precision { let precision_max = 10_i128.pow(precision as u32); - // note: this is not too efficient as it scans through the data twice... - if let (Some(min), Some(max)) = (self.min(), self.max()) { + if let Some((min, max)) = self.min_max() { let max_abs = max.abs().max(min.abs()); polars_ensure!( max_abs < precision_max, @@ -83,9 +83,7 @@ impl LogicalType for DecimalChunked { fn cast(&self, dtype: &DataType) -> PolarsResult { let (precision_src, scale_src) = (self.precision(), self.scale()); if let &DataType::Decimal(precision_dst, scale_dst) = dtype { - let scale_dst = scale_dst.ok_or_else( - || polars_err!(ComputeError: "cannot cast to Decimal with unknown scale"), - )?; + let scale_dst = scale_dst.unwrap_or(scale_src); // for now, let's just allow same-scale conversions // where precision is either the same or bigger or gets converted to `None` // (these are the easy cases requiring no checks and arithmetics which we can add later) @@ -95,6 +93,7 @@ impl LogicalType for DecimalChunked { _ => false, }; if scale_src == scale_dst && is_widen { + let dtype = &DataType::Decimal(precision_dst, Some(scale_dst)); return self.0.cast(dtype); // no conversion or checks needed } } @@ -123,4 +122,16 @@ impl DecimalChunked { _ => unreachable!(), } } + + pub(crate) fn to_scale(&self, scale: usize) -> PolarsResult> { + if self.scale() == scale { + return Ok(Cow::Borrowed(self)); + } + + let dtype = DataType::Decimal(None, Some(scale)); + let chunks = cast_chunks(&self.chunks, &dtype, true)?; + let mut dt = Self::new_logical(unsafe { Int128Chunked::from_chunks(self.name(), chunks) }); + dt.2 = Some(dtype); + Ok(Cow::Owned(dt)) + } } diff --git a/crates/polars-core/src/chunked_array/logical/duration.rs b/crates/polars-core/src/chunked_array/logical/duration.rs index f6a010810ccc..64ef1620c3c0 100644 --- a/crates/polars-core/src/chunked_array/logical/duration.rs +++ b/crates/polars-core/src/chunked_array/logical/duration.rs @@ -41,7 +41,7 @@ impl LogicalType for DurationChunked { .into_series()) }, (Duration(TimeUnit::Microseconds), Duration(TimeUnit::Milliseconds)) => { - Ok((self.0.as_ref() / 1_000i64) + Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000i64)) .into_duration(TimeUnit::Milliseconds) .into_series()) }, @@ -51,12 +51,12 @@ impl LogicalType for DurationChunked { .into_series()) }, (Duration(TimeUnit::Nanoseconds), Duration(TimeUnit::Milliseconds)) => { - Ok((self.0.as_ref() / 1_000_000i64) + Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000_000i64)) .into_duration(TimeUnit::Milliseconds) .into_series()) }, (Duration(TimeUnit::Nanoseconds), Duration(TimeUnit::Microseconds)) => { - Ok((self.0.as_ref() / 1_000i64) + Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000i64)) .into_duration(TimeUnit::Microseconds) .into_series()) }, diff --git a/crates/polars-core/src/chunked_array/logical/struct_/from.rs b/crates/polars-core/src/chunked_array/logical/struct_/from.rs index 23e23410b3f0..4ec570767282 100644 --- a/crates/polars-core/src/chunked_array/logical/struct_/from.rs +++ b/crates/polars-core/src/chunked_array/logical/struct_/from.rs @@ -4,11 +4,11 @@ impl From for DataFrame { fn from(ca: StructChunked) -> Self { #[cfg(feature = "object")] { - DataFrame::new_no_checks(ca.fields.clone()) + unsafe { DataFrame::new_no_checks(ca.fields.clone()) } } #[cfg(not(feature = "object"))] { - DataFrame::new_no_checks(ca.fields) + unsafe { DataFrame::new_no_checks(ca.fields) } } } } diff --git a/crates/polars-core/src/chunked_array/logical/struct_/mod.rs b/crates/polars-core/src/chunked_array/logical/struct_/mod.rs index eaad94169ecd..9457bfe00bd1 100644 --- a/crates/polars-core/src/chunked_array/logical/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/struct_/mod.rs @@ -9,7 +9,9 @@ use arrow::legacy::trusted_len::TrustedLenPush; use arrow::offset::OffsetsBuffer; use smartstring::alias::String as SmartString; +use self::sort::arg_sort_multiple::_get_rows_encoded_ca; use super::*; +use crate::chunked_array::iterator::StructIter; use crate::datatypes::*; use crate::utils::index_to_chunked_index; @@ -112,7 +114,7 @@ impl StructChunked { } Ok(Self::new_unchecked(name, &new_fields)) } else if fields.is_empty() { - let fields = &[Series::full_null("", 0, &DataType::Null)]; + let fields = &[Series::new_null("", 0)]; Ok(Self::new_unchecked(name, fields)) } else { Ok(Self::new_unchecked(name, fields)) @@ -282,7 +284,7 @@ impl StructChunked { Ok(Self::new_unchecked(self.field.name(), &fields)) } - pub(crate) fn apply_fields(&self, func: F) -> Self + pub fn _apply_fields(&self, func: F) -> Self where F: FnMut(&Series) -> Series, { @@ -347,7 +349,7 @@ impl StructChunked { let mut length_so_far = 0_i64; unsafe { - // safety: we have pre-allocated + // SAFETY: we have pre-allocated offsets.push_unchecked(length_so_far); } for row in 0..ca.len() { @@ -364,7 +366,7 @@ impl StructChunked { unsafe { *values.last_mut().unwrap_unchecked() = b'}'; - // safety: we have pre-allocated + // SAFETY: we have pre-allocated length_so_far = values.len() as i64; offsets.push_unchecked(length_so_far); } @@ -411,6 +413,15 @@ impl StructChunked { } self.cast_impl(dtype, true) } + + pub fn rows_encode(&self) -> PolarsResult { + let descending = vec![false; self.fields.len()]; + _get_rows_encoded_ca(self.name(), &self.fields, &descending, false) + } + + pub fn iter(&self) -> StructIter { + self.into_iter() + } } impl LogicalType for StructChunked { @@ -427,7 +438,7 @@ impl LogicalType for StructChunked { unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { let (chunk_idx, idx) = index_to_chunked_index(self.chunks.iter().map(|c| c.len()), i); if let DataType::Struct(flds) = self.dtype() { - // safety: we already have a single chunk and we are + // SAFETY: we already have a single chunk and we are // guarded by the type system. unsafe { let arr = &**self.chunks.get_unchecked(chunk_idx); diff --git a/crates/polars-core/src/chunked_array/logical/time.rs b/crates/polars-core/src/chunked_array/logical/time.rs index 8710e1c12426..8dd4c6239ae9 100644 --- a/crates/polars-core/src/chunked_array/logical/time.rs +++ b/crates/polars-core/src/chunked_array/logical/time.rs @@ -29,8 +29,9 @@ impl LogicalType for TimeChunked { } fn cast(&self, dtype: &DataType) -> PolarsResult { + use DataType::*; match dtype { - DataType::Duration(tu) => { + Duration(tu) => { let out = self.0.cast(&DataType::Duration(TimeUnit::Nanoseconds)); if !matches!(tu, TimeUnit::Nanoseconds) { out?.cast(dtype) @@ -38,6 +39,14 @@ impl LogicalType for TimeChunked { out } }, + #[cfg(feature = "dtype-date")] + Date => { + polars_bail!(ComputeError: "cannot cast `Time` to `Date`"); + }, + #[cfg(feature = "dtype-datetime")] + Datetime(_, _) => { + polars_bail!(ComputeError: "cannot cast `Time` to `Datetime`; consider using `dt.combine`"); + }, _ => self.0.cast(dtype), } } diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index 98e8c5c96b10..7e8cc98a11f0 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -52,7 +52,7 @@ use arrow::legacy::prelude::*; use bitflags::bitflags; use crate::series::IsSorted; -use crate::utils::{first_non_null, last_non_null, CustomIterTools}; +use crate::utils::{first_non_null, last_non_null}; #[cfg(not(feature = "dtype-categorical"))] pub struct RevMapping {} @@ -96,12 +96,12 @@ pub type ChunkIdIter<'a> = std::iter::Map, fn(&Ar /// # use polars_core::prelude::*; /// /// fn iter_forward(ca: &Float32Chunked) { -/// ca.into_iter() +/// ca.iter() /// .for_each(|opt_v| println!("{:?}", opt_v)) /// } /// /// fn iter_backward(ca: &Float32Chunked) { -/// ca.into_iter() +/// ca.iter() /// .rev() /// .for_each(|opt_v| println!("{:?}", opt_v)) /// } @@ -211,6 +211,13 @@ impl ChunkedArray { self.bit_settings.set_sorted_flag(sorted) } + /// Set the 'sorted' bit meta info. + pub fn with_sorted_flag(&self, sorted: IsSorted) -> Self { + let mut out = self.clone(); + out.bit_settings.set_sorted_flag(sorted); + out + } + /// Get the index of the first non null value in this [`ChunkedArray`]. pub fn first_non_null(&self) -> Option { if self.is_empty() { @@ -296,7 +303,7 @@ impl ChunkedArray { series.dtype(), self.dtype(), ); - // Safety + // SAFETY: // dtype will be correct. Ok(unsafe { self.unpack_series_matching_physical_type(series) }) } @@ -574,7 +581,7 @@ where /// Contiguous mutable slice pub(crate) fn cont_slice_mut(&mut self) -> Option<&mut [T::Native]> { if self.chunks.len() == 1 && self.chunks[0].null_count() == 0 { - // Safety, we will not swap the PrimitiveArray. + // SAFETY, we will not swap the PrimitiveArray. let arr = unsafe { self.downcast_iter_mut().next().unwrap() }; arr.get_mut_values() } else { @@ -759,10 +766,7 @@ pub(crate) mod test { where T: PolarsNumericType, { - assert_eq!( - ca.into_iter().map(|opt| opt.unwrap()).collect::>(), - eq - ) + assert_eq!(ca.iter().map(|opt| opt.unwrap()).collect::>(), eq) } #[test] diff --git a/crates/polars-core/src/chunked_array/ndarray.rs b/crates/polars-core/src/chunked_array/ndarray.rs index bc043da81e02..77dc7d3d5ceb 100644 --- a/crates/polars-core/src/chunked_array/ndarray.rs +++ b/crates/polars-core/src/chunked_array/ndarray.rs @@ -68,7 +68,7 @@ impl ListChunked { } debug_assert_eq!(row_idx, self.len()); - // Safety: + // SAFETY: // We have assigned to every row and element of the array unsafe { Ok(ndarray.assume_init()) } } @@ -142,7 +142,7 @@ impl DataFrame { let vals = ca.cont_slice().unwrap(); // Depending on the desired order, we add items to the buffer. - // Safety: + // SAFETY: // We get parallel access to the vector by offsetting index access accordingly. // For C-order, we only operate on every num-col-th element, starting from the // column index. For Fortran-order we only operate on n contiguous elements, @@ -158,7 +158,7 @@ impl DataFrame { }, IndexOrder::Fortran => unsafe { let offset_ptr = (ptr as *mut N::Native).add(col_idx * height); - // Safety: + // SAFETY: // this is uninitialized memory, so we must never read from this data // copy_from_slice does not read let buf = std::slice::from_raw_parts_mut(offset_ptr, height); @@ -170,7 +170,7 @@ impl DataFrame { .collect::>>() })?; - // Safety: + // SAFETY: // we have written all data, so we can now safely set length unsafe { membuf.set_len(shape.0 * shape.1); diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index 00827c6e14f8..1bf6727342a9 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -1,5 +1,4 @@ use std::marker::PhantomData; -use std::sync::Arc; use arrow::bitmap::MutableBitmap; @@ -194,7 +193,7 @@ pub(crate) fn object_series_to_arrow_array(s: &Series) -> ArrayRef { // The list builder knows how to create an arrow array // we simply piggy back on that code. - // safety: 0..len is in bounds + // SAFETY: 0..len is in bounds let list_s = unsafe { s.agg_list(&GroupsProxy::Slice { groups: vec![[0, s.len() as IdxSize]], diff --git a/crates/polars-core/src/chunked_array/object/extension/mod.rs b/crates/polars-core/src/chunked_array/object/extension/mod.rs index 7b0edc76ba5f..df5797769b91 100644 --- a/crates/polars-core/src/chunked_array/object/extension/mod.rs +++ b/crates/polars-core/src/chunked_array/object/extension/mod.rs @@ -4,7 +4,7 @@ pub(crate) mod polars_extension; use std::mem; -use arrow::array::{Array, FixedSizeBinaryArray}; +use arrow::array::FixedSizeBinaryArray; use arrow::bitmap::MutableBitmap; use arrow::buffer::Buffer; use polars_extension::PolarsExtension; @@ -76,7 +76,7 @@ pub(crate) fn create_extension> + TrustedLen, T: Si Some(t) => { unsafe { buf.extend_from_slice(any_as_u8_slice(&t)); - // Safety: we allocated upfront + // SAFETY: we allocated upfront validity.push_unchecked(true) } mem::forget(t); @@ -85,7 +85,7 @@ pub(crate) fn create_extension> + TrustedLen, T: Si null_count += 1; unsafe { buf.extend_from_slice(any_as_u8_slice(&T::default())); - // Safety: we allocated upfront + // SAFETY: we allocated upfront validity.push_unchecked(false) } }, @@ -101,7 +101,7 @@ pub(crate) fn create_extension> + TrustedLen, T: Si // ptr to start of T, not to start of padding let ptr = buf.as_slice().as_ptr(); - // Safety: + // SAFETY: // ptr and t are correct let drop_fn = unsafe { create_drop::(ptr, n_t_vals) }; let et = Box::new(ExtensionSentinel { @@ -125,7 +125,7 @@ pub(crate) fn create_extension> + TrustedLen, T: Si let array = FixedSizeBinaryArray::new(extension_type, buf, validity); - // Safety: + // SAFETY: // we just heap allocated the ExtensionSentinel, so its alive. unsafe { PolarsExtension::new(array) } } @@ -133,8 +133,10 @@ pub(crate) fn create_extension> + TrustedLen, T: Si #[cfg(test)] mod test { use std::fmt::{Display, Formatter}; + use std::hash::{Hash, Hasher}; - use polars_utils::idxvec; + use polars_utils::total_ord::TotalHash; + use polars_utils::unitvec; use super::*; @@ -151,6 +153,15 @@ mod test { } } + impl TotalHash for Foo { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state); + } + } + impl Display for Foo { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}", self) @@ -200,7 +211,7 @@ mod test { let ca = ObjectChunked::new("", values); let groups = - GroupsProxy::Idx(vec![(0, idxvec![0, 1]), (2, idxvec![2]), (3, idxvec![3])].into()); + GroupsProxy::Idx(vec![(0, unitvec![0, 1]), (2, unitvec![2]), (3, unitvec![3])].into()); let out = unsafe { ca.agg_list(&groups) }; assert!(matches!(out.dtype(), DataType::List(_))); assert_eq!(out.len(), groups.len()); @@ -223,7 +234,7 @@ mod test { let values = &[Some(foo1.clone()), None, Some(foo2.clone()), None]; let ca = ObjectChunked::new("", values); - let groups = vec![(0, idxvec![0, 1]), (2, idxvec![2]), (3, idxvec![3])].into(); + let groups = vec![(0, unitvec![0, 1]), (2, unitvec![2]), (3, unitvec![3])].into(); let out = unsafe { ca.agg_list(&GroupsProxy::Idx(groups)) }; let a = out.explode().unwrap(); diff --git a/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs b/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs index 4cd2de7abad9..6030f668dfe1 100644 --- a/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs +++ b/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs @@ -1,7 +1,5 @@ use std::mem::ManuallyDrop; -use arrow::array::FixedSizeBinaryArray; - use super::*; use crate::prelude::*; diff --git a/crates/polars-core/src/chunked_array/object/iterator.rs b/crates/polars-core/src/chunked_array/object/iterator.rs index 5433f048be46..7a5c6e00b590 100644 --- a/crates/polars-core/src/chunked_array/object/iterator.rs +++ b/crates/polars-core/src/chunked_array/object/iterator.rs @@ -29,7 +29,7 @@ impl<'a, T: PolarsObject> std::iter::Iterator for ObjectIter<'a, T> { fn next(&mut self) -> Option { if self.current == self.current_end { None - // Safety: + // SAFETY: // Se comment below } else if unsafe { self.array.is_null_unchecked(self.current) } { self.current += 1; @@ -37,7 +37,7 @@ impl<'a, T: PolarsObject> std::iter::Iterator for ObjectIter<'a, T> { } else { let old = self.current; self.current += 1; - // Safety: + // SAFETY: // we just checked bounds in `self.current_end == self.current` // this is safe on the premise that this struct is initialized with // current = array.len() @@ -63,7 +63,7 @@ impl<'a, T: PolarsObject> std::iter::DoubleEndedIterator for ObjectIter<'a, T> { Some(if self.array.is_null(self.current_end) { None } else { - // Safety: + // SAFETY: // we just checked bounds in `self.current_end == self.current` // this is safe on the premise that this struct is initialized with // current = array.len() @@ -118,7 +118,7 @@ impl std::iter::Iterator for OwnedObjectIter { fn next(&mut self) -> Option { if self.current == self.current_end { None - // Safety: + // SAFETY: // Se comment below } else if unsafe { self.array.is_null_unchecked(self.current) } { self.current += 1; @@ -126,7 +126,7 @@ impl std::iter::Iterator for OwnedObjectIter { } else { let old = self.current; self.current += 1; - // Safety: + // SAFETY: // we just checked bounds in `self.current_end == self.current` // this is safe on the premise that this struct is initialized with // current = array.len() diff --git a/crates/polars-core/src/chunked_array/object/mod.rs b/crates/polars-core/src/chunked_array/object/mod.rs index f51eb63c3fca..9f17a1d1b434 100644 --- a/crates/polars-core/src/chunked_array/object/mod.rs +++ b/crates/polars-core/src/chunked_array/object/mod.rs @@ -4,6 +4,7 @@ use std::hash::Hash; use arrow::bitmap::utils::{BitmapIter, ZipValidity}; use arrow::bitmap::Bitmap; +use polars_utils::total_ord::TotalHash; use crate::prelude::*; @@ -36,7 +37,7 @@ pub trait PolarsObjectSafe: Any + Debug + Send + Sync + Display { /// Values need to implement this so that they can be stored into a Series and DataFrame pub trait PolarsObject: - Any + Debug + Clone + Send + Sync + Default + Display + Hash + PartialEq + Eq + TotalEq + Any + Debug + Clone + Send + Sync + Default + Display + Hash + TotalHash + PartialEq + Eq + TotalEq { /// This should be used as type information. Consider this a part of the type system. fn type_name() -> &'static str; @@ -121,15 +122,6 @@ where !self.is_valid_unchecked(i) } - #[inline] - pub(crate) unsafe fn get_unchecked(&self, item: usize) -> Option<&T> { - if self.is_null_unchecked(item) { - None - } else { - Some(self.value_unchecked(item)) - } - } - /// Returns this array with a new validity. /// # Panic /// Panics iff `validity.len() != self.len()`. @@ -217,11 +209,19 @@ where /// /// No bounds checks pub unsafe fn get_object_unchecked(&self, index: usize) -> Option<&dyn PolarsObjectSafe> { - let chunks = self.downcast_chunks(); let (chunk_idx, idx) = self.index_to_chunked_index(index); - let arr = chunks.get_unchecked(chunk_idx); - if arr.is_valid_unchecked(idx) { - Some(arr.value(idx)) + self.get_object_chunked_unchecked(chunk_idx, idx) + } + + pub(crate) unsafe fn get_object_chunked_unchecked( + &self, + chunk: usize, + index: usize, + ) -> Option<&dyn PolarsObjectSafe> { + let chunks = self.downcast_chunks(); + let arr = chunks.get_unchecked(chunk); + if arr.is_valid_unchecked(index) { + Some(arr.value(index)) } else { None } diff --git a/crates/polars-core/src/chunked_array/object/registry.rs b/crates/polars-core/src/chunked_array/object/registry.rs index e34c4b77041b..5ebcad2a022a 100644 --- a/crates/polars-core/src/chunked_array/object/registry.rs +++ b/crates/polars-core/src/chunked_array/object/registry.rs @@ -14,6 +14,7 @@ use crate::datatypes::AnyValue; use crate::prelude::PolarsObject; use crate::series::{IntoSeries, Series}; +/// Takes a `name` and `capacity` and constructs a new builder. pub type BuilderConstructor = Box Box + Send + Sync>; pub type ObjectConverter = Arc Box + Send + Sync>; @@ -58,6 +59,13 @@ pub trait AnonymousObjectBuilder { /// [ObjectChunked]: crate::chunked_array::object::ObjectChunked fn append_value(&mut self, value: &dyn Any); + fn append_option(&mut self, value: Option<&dyn Any>) { + match value { + None => self.append_null(), + Some(v) => self.append_value(v), + } + } + /// Take the current state and materialize as a [`Series`] /// the builder should not be used after that. fn to_series(&mut self) -> Series; diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index 78825e40bfe4..a67627ad6a04 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -142,6 +142,45 @@ where } } + fn min_max(&self) -> Option<(T::Native, T::Native)> { + if self.is_empty() { + return None; + } + match self.is_sorted_flag() { + IsSorted::Ascending => { + let min = self.first_non_null().and_then(|idx| { + // SAFETY: first_non_null returns in bound index. + unsafe { self.get_unchecked(idx) } + }); + let max = self.last_non_null().and_then(|idx| { + // SAFETY: last_non_null returns in bound index. + unsafe { self.get_unchecked(idx) } + }); + min.zip(max) + }, + IsSorted::Descending => { + let max = self.first_non_null().and_then(|idx| { + // SAFETY: first_non_null returns in bound index. + unsafe { self.get_unchecked(idx) } + }); + let min = self.last_non_null().and_then(|idx| { + // SAFETY: last_non_null returns in bound index. + unsafe { self.get_unchecked(idx) } + }); + min.zip(max) + }, + IsSorted::Not => self + .downcast_iter() + .filter_map(MinMaxKernel::min_max_ignore_nan_kernel) + .reduce(|(min1, max1), (min2, max2)| { + ( + MinMax::min_ignore_nan(min1, min2), + MinMax::max_ignore_nan(max1, max2), + ) + }), + } + } + fn mean(&self) -> Option { if self.is_empty() || self.null_count() == self.len() { return None; @@ -175,7 +214,7 @@ where for arr in self.downcast_iter() { if arr.null_count() > 0 { for v in arr.into_iter().flatten() { - // safety + // SAFETY: // all these types can be coerced to f64 unsafe { let val = v.to_f64().unwrap_unchecked(); @@ -184,7 +223,7 @@ where } } else { for v in arr.values().as_slice() { - // safety + // SAFETY: // all these types can be coerced to f64 unsafe { let val = v.to_f64().unwrap_unchecked(); @@ -475,6 +514,75 @@ impl ChunkAggSeries for StringChunked { } } +#[cfg(feature = "dtype-categorical")] +impl CategoricalChunked { + fn min_categorical(&self) -> Option<&str> { + if self.is_empty() || self.null_count() == self.len() { + return None; + } + if self.uses_lexical_ordering() { + // Fast path where all categories are used + if self._can_fast_unique() { + self.get_rev_map().get_categories().min_ignore_nan_kernel() + } else { + let rev_map = self.get_rev_map(); + // SAFETY: + // Indices are in bounds + self.physical() + .iter() + .flat_map(|opt_el: Option| { + opt_el.map(|el| unsafe { rev_map.get_unchecked(el) }) + }) + .min() + } + } else { + // SAFETY: + // Indices are in bounds + self.physical() + .min() + .map(|el| unsafe { self.get_rev_map().get_unchecked(el) }) + } + } + + fn max_categorical(&self) -> Option<&str> { + if self.is_empty() || self.null_count() == self.len() { + return None; + } + if self.uses_lexical_ordering() { + // Fast path where all categories are used + if self._can_fast_unique() { + self.get_rev_map().get_categories().max_ignore_nan_kernel() + } else { + let rev_map = self.get_rev_map(); + // SAFETY: + // Indices are in bounds + self.physical() + .iter() + .flat_map(|opt_el: Option| { + opt_el.map(|el| unsafe { rev_map.get_unchecked(el) }) + }) + .max() + } + } else { + // SAFETY: + // Indices are in bounds + self.physical() + .max() + .map(|el| unsafe { self.get_rev_map().get_unchecked(el) }) + } + } +} + +#[cfg(feature = "dtype-categorical")] +impl ChunkAggSeries for CategoricalChunked { + fn min_as_series(&self) -> Series { + Series::new(self.name(), &[self.min_categorical()]) + } + fn max_as_series(&self) -> Series { + Series::new(self.name(), &[self.max_categorical()]) + } +} + impl BinaryChunked { pub(crate) fn max_binary(&self) -> Option<&[u8]> { if self.is_empty() { @@ -542,8 +650,6 @@ impl ChunkAggSeries for ObjectChunked {} #[cfg(test)] mod test { - use arrow::legacy::prelude::QuantileInterpolOptions; - use crate::prelude::*; #[test] diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs index 691f66e568c4..ce528337f0c1 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs @@ -1,5 +1,3 @@ -use arrow::legacy::prelude::QuantileInterpolOptions; - use super::*; pub trait QuantileAggSeries { diff --git a/crates/polars-core/src/chunked_array/ops/any_value.rs b/crates/polars-core/src/chunked_array/ops/any_value.rs index 98c547143178..7182bcacb01a 100644 --- a/crates/polars-core/src/chunked_array/ops/any_value.rs +++ b/crates/polars-core/src/chunked_array/ops/any_value.rs @@ -205,7 +205,7 @@ macro_rules! get_any_value { if $index >= $self.len() { polars_bail!(oob = $index, $self.len()); } - // SAFETY + // SAFETY: // bounds are checked Ok(unsafe { $self.get_any_value_unchecked($index) }) }}; diff --git a/crates/polars-core/src/chunked_array/ops/apply.rs b/crates/polars-core/src/chunked_array/ops/apply.rs index b70aa05c8970..c4d77df7e09e 100644 --- a/crates/polars-core/src/chunked_array/ops/apply.rs +++ b/crates/polars-core/src/chunked_array/ops/apply.rs @@ -1,14 +1,11 @@ //! Implementations of the ChunkApply Trait. use std::borrow::Cow; -use std::convert::TryFrom; -use arrow::array::{BooleanArray, PrimitiveArray}; use arrow::bitmap::utils::{get_bit_unchecked, set_bit_unchecked}; use arrow::legacy::bitmap::unary_mut; use crate::prelude::*; use crate::series::IsSorted; -use crate::utils::CustomIterTools; impl ChunkedArray where @@ -214,7 +211,7 @@ impl ChunkedArray { where F: Fn(T::Native) -> T::Native + Copy, { - // safety, we do no t change the lengths + // SAFETY, we do no t change the lengths unsafe { self.downcast_iter_mut() .for_each(|arr| arrow::compute::arity_assign::unary(arr, f)) @@ -281,7 +278,7 @@ where let mut idx = 0; self.downcast_iter().for_each(|arr| { arr.into_iter().for_each(|opt_val| { - // Safety: + // SAFETY: // length asserted above let item = unsafe { slice.get_unchecked_mut(idx) }; *item = f(opt_val.copied(), item); @@ -371,7 +368,7 @@ impl<'a> ChunkApply<'a, bool> for BooleanChunked { let mut idx = 0; self.downcast_iter().for_each(|arr| { arr.into_iter().for_each(|opt_val| { - // Safety: + // SAFETY: // length asserted above let item = unsafe { slice.get_unchecked_mut(idx) }; *item = f(opt_val, item); @@ -457,7 +454,7 @@ impl<'a> ChunkApply<'a, &'a str> for StringChunked { let mut idx = 0; self.downcast_iter().for_each(|arr| { arr.into_iter().for_each(|opt_val| { - // Safety: + // SAFETY: // length asserted above let item = unsafe { slice.get_unchecked_mut(idx) }; *item = f(opt_val, item); @@ -500,7 +497,7 @@ impl<'a> ChunkApply<'a, &'a [u8]> for BinaryChunked { let mut idx = 0; self.downcast_iter().for_each(|arr| { arr.into_iter().for_each(|opt_val| { - // Safety: + // SAFETY: // length asserted above let item = unsafe { slice.get_unchecked_mut(idx) }; *item = f(opt_val, item); @@ -663,7 +660,7 @@ impl<'a> ChunkApply<'a, Series> for ListChunked { arr.iter().for_each(|opt_val| { let opt_val = opt_val.map(|arrayref| Series::try_from(("", arrayref)).unwrap()); - // Safety: + // SAFETY: // length asserted above let item = unsafe { slice.get_unchecked_mut(idx) }; *item = f(opt_val, item); @@ -713,7 +710,7 @@ where let mut idx = 0; self.downcast_iter().for_each(|arr| { arr.into_iter().for_each(|opt_val| { - // Safety: + // SAFETY: // length asserted above let item = unsafe { slice.get_unchecked_mut(idx) }; *item = f(opt_val, item); diff --git a/crates/polars-core/src/chunked_array/ops/arity.rs b/crates/polars-core/src/chunked_array/ops/arity.rs index d0a298a46e03..cafdc8694182 100644 --- a/crates/polars-core/src/chunked_array/ops/arity.rs +++ b/crates/polars-core/src/chunked_array/ops/arity.rs @@ -1,12 +1,12 @@ use std::error::Error; -use arrow::array::Array; +use arrow::array::{Array, StaticArray}; use arrow::compute::utils::combine_validities_and; use polars_error::PolarsResult; -use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray}; +use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter}; use crate::prelude::{ChunkedArray, PolarsDataType, Series}; -use crate::utils::{align_chunks_binary, align_chunks_ternary}; +use crate::utils::{align_chunks_binary, align_chunks_binary_owned, align_chunks_ternary}; // We need this helper because for<'a> notation can't yet be applied properly // on the return type. @@ -38,6 +38,33 @@ impl R> BinaryFnMut for T { type Ret = R; } +/// Applies a kernel that produces `Array` types. +#[inline] +pub fn unary_kernel(ca: &ChunkedArray, op: F) -> ChunkedArray +where + T: PolarsDataType, + V: PolarsDataType, + Arr: Array, + F: FnMut(&T::Array) -> Arr, +{ + let iter = ca.downcast_iter().map(op); + ChunkedArray::from_chunk_iter(ca.name(), iter) +} + +/// Applies a kernel that produces `Array` types. +#[inline] +pub fn unary_kernel_owned(ca: ChunkedArray, op: F) -> ChunkedArray +where + T: PolarsDataType, + V: PolarsDataType, + Arr: Array, + F: FnMut(T::Array) -> Arr, +{ + let name = ca.name().to_owned(); + let iter = ca.downcast_into_iter().map(op); + ChunkedArray::from_chunk_iter(&name, iter) +} + #[inline] pub fn unary_elementwise<'a, T, V, F>(ca: &'a ChunkedArray, mut op: F) -> ChunkedArray where @@ -435,6 +462,28 @@ where binary_mut_with_options(lhs, rhs, op, lhs.name()) } +/// Applies a kernel that produces `Array` types. +pub fn binary_owned( + lhs: ChunkedArray, + rhs: ChunkedArray, + mut op: F, +) -> ChunkedArray +where + L: PolarsDataType, + R: PolarsDataType, + V: PolarsDataType, + Arr: Array, + F: FnMut(L::Array, R::Array) -> Arr, +{ + let name = lhs.name().to_owned(); + let (lhs, rhs) = align_chunks_binary_owned(lhs, rhs); + let iter = lhs + .downcast_into_iter() + .zip(rhs.downcast_into_iter()) + .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)); + ChunkedArray::from_chunk_iter(&name, iter) +} + /// Applies a kernel that produces `Array` types. pub fn try_binary( lhs: &ChunkedArray, @@ -725,3 +774,91 @@ where _ => try_binary_elementwise_values(lhs, rhs, op), } } + +pub fn apply_binary_kernel_broadcast<'l, 'r, L, R, O, K, LK, RK>( + lhs: &'l ChunkedArray, + rhs: &'r ChunkedArray, + kernel: K, + lhs_broadcast_kernel: LK, + rhs_broadcast_kernel: RK, +) -> ChunkedArray +where + L: PolarsDataType, + R: PolarsDataType, + O: PolarsDataType, + K: Fn(&L::Array, &R::Array) -> O::Array, + LK: Fn(L::Physical<'l>, &R::Array) -> O::Array, + RK: Fn(&L::Array, R::Physical<'r>) -> O::Array, +{ + let name = lhs.name(); + let out = match (lhs.len(), rhs.len()) { + (a, b) if a == b => binary(lhs, rhs, |lhs, rhs| kernel(lhs, rhs)), + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => { + let arr = O::Array::full_null(lhs.len(), O::get_dtype().to_arrow(true)); + ChunkedArray::::with_chunk(lhs.name(), arr) + }, + Some(rhs) => unary_kernel(lhs, |arr| rhs_broadcast_kernel(arr, rhs.clone())), + } + }, + (1, _) => { + let opt_lhs = lhs.get(0); + match opt_lhs { + None => { + let arr = O::Array::full_null(rhs.len(), O::get_dtype().to_arrow(true)); + ChunkedArray::::with_chunk(lhs.name(), arr) + }, + Some(lhs) => unary_kernel(rhs, |arr| lhs_broadcast_kernel(lhs.clone(), arr)), + } + }, + _ => panic!("Cannot apply operation on arrays of different lengths"), + }; + out.with_name(name) +} + +pub fn apply_binary_kernel_broadcast_owned( + lhs: ChunkedArray, + rhs: ChunkedArray, + kernel: K, + lhs_broadcast_kernel: LK, + rhs_broadcast_kernel: RK, +) -> ChunkedArray +where + L: PolarsDataType, + R: PolarsDataType, + O: PolarsDataType, + K: Fn(L::Array, R::Array) -> O::Array, + for<'a> LK: Fn(L::Physical<'a>, R::Array) -> O::Array, + for<'a> RK: Fn(L::Array, R::Physical<'a>) -> O::Array, +{ + let name = lhs.name().to_owned(); + let out = match (lhs.len(), rhs.len()) { + (a, b) if a == b => binary_owned(lhs, rhs, kernel), + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => { + let arr = O::Array::full_null(lhs.len(), O::get_dtype().to_arrow(true)); + ChunkedArray::::with_chunk(lhs.name(), arr) + }, + Some(rhs) => unary_kernel_owned(lhs, |arr| rhs_broadcast_kernel(arr, rhs.clone())), + } + }, + (1, _) => { + let opt_lhs = lhs.get(0); + match opt_lhs { + None => { + let arr = O::Array::full_null(rhs.len(), O::get_dtype().to_arrow(true)); + ChunkedArray::::with_chunk(lhs.name(), arr) + }, + Some(lhs) => unary_kernel_owned(rhs, |arr| lhs_broadcast_kernel(lhs.clone(), arr)), + } + }, + _ => panic!("Cannot apply operation on arrays of different lengths"), + }; + out.with_name(&name) +} diff --git a/crates/polars-core/src/chunked_array/ops/chunkops.rs b/crates/polars-core/src/chunked_array/ops/chunkops.rs index 2a589fded75c..3b4ed61f6ea7 100644 --- a/crates/polars-core/src/chunked_array/ops/chunkops.rs +++ b/crates/polars-core/src/chunked_array/ops/chunkops.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "object")] -use arrow::array::Array; use arrow::legacy::kernels::concatenate::concatenate_owned_unchecked; use polars_error::constants::LENGTH_LIMIT_MSG; @@ -37,7 +35,7 @@ pub(crate) fn slice( debug_assert!(remaining_offset + take_len <= chunk.len()); unsafe { - // Safety: + // SAFETY: // this function ensures the slices are in bounds new_chunks.push(chunk.sliced_unchecked(remaining_offset, take_len)); } diff --git a/crates/polars-core/src/chunked_array/ops/compare_inner.rs b/crates/polars-core/src/chunked_array/ops/compare_inner.rs index ccbca41e258a..02981d585144 100644 --- a/crates/polars-core/src/chunked_array/ops/compare_inner.rs +++ b/crates/polars-core/src/chunked_array/ops/compare_inner.rs @@ -4,6 +4,7 @@ use std::cmp::Ordering; use crate::chunked_array::ChunkedArrayLayout; use crate::prelude::*; +use crate::series::implementations::null::NullChunked; #[repr(transparent)] struct NonNull(T); @@ -64,12 +65,24 @@ where } } +impl TotalEqInner for &NullChunked { + unsafe fn eq_element_unchecked(&self, _idx_a: usize, _idx_b: usize) -> bool { + true + } +} + /// Create a type that implements TotalEqInner. pub(crate) trait IntoTotalEqInner<'a> { /// Create a type that implements `TakeRandom`. fn into_total_eq_inner(self) -> Box; } +impl<'a> IntoTotalEqInner<'a> for &'a NullChunked { + fn into_total_eq_inner(self) -> Box { + Box::new(self) + } +} + /// We use a trait object because we want to call this from Series and cannot use a typed enum. impl<'a, T> IntoTotalEqInner<'a> for &'a ChunkedArray where diff --git a/crates/polars-core/src/chunked_array/ops/downcast.rs b/crates/polars-core/src/chunked_array/ops/downcast.rs index f3303a08005a..435c43f82ca3 100644 --- a/crates/polars-core/src/chunked_array/ops/downcast.rs +++ b/crates/polars-core/src/chunked_array/ops/downcast.rs @@ -48,6 +48,16 @@ impl<'a, T> Chunks<'a, T> { #[doc(hidden)] impl ChunkedArray { + #[inline] + pub fn downcast_into_iter(mut self) -> impl DoubleEndedIterator { + let chunks = std::mem::take(&mut self.chunks); + chunks.into_iter().map(|arr| { + // SAFETY: T::Array guarantees this is correct. + let ptr = Box::into_raw(arr).cast::(); + unsafe { *Box::from_raw(ptr) } + }) + } + #[inline] pub fn downcast_iter(&self) -> impl DoubleEndedIterator { self.chunks.iter().map(|arr| { @@ -57,6 +67,19 @@ impl ChunkedArray { }) } + #[inline] + pub fn downcast_slices(&self) -> Option]>> { + if self.null_count != 0 { + return None; + } + let arr = self.downcast_iter().next().unwrap(); + if arr.as_slice().is_some() { + Some(self.downcast_iter().map(|arr| arr.as_slice().unwrap())) + } else { + None + } + } + /// # Safety /// The caller must ensure: /// * the length remains correct. diff --git a/crates/polars-core/src/chunked_array/ops/explode.rs b/crates/polars-core/src/chunked_array/ops/explode.rs index 5e54a5fda1ad..6107b10ab696 100644 --- a/crates/polars-core/src/chunked_array/ops/explode.rs +++ b/crates/polars-core/src/chunked_array/ops/explode.rs @@ -1,9 +1,6 @@ -use std::convert::TryFrom; - use arrow::array::*; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::legacy::array::list::AnonymousBuilder; -use arrow::legacy::array::PolarsArray; use arrow::legacy::bit_util::unset_bit_raw; #[cfg(feature = "dtype-array")] use arrow::legacy::is_valid::IsValid; @@ -85,7 +82,7 @@ where new_values.extend_from_slice(values.get_unchecked(start..last)) }; - // Safety: + // SAFETY: // we are in bounds unsafe { unset_nulls( @@ -107,7 +104,7 @@ where } // final null check - // Safety: + // SAFETY: // we are in bounds unsafe { unset_nulls( @@ -252,9 +249,9 @@ impl ExplodeByOffsets for ListChunked { unsafe { // we create a pointer to evade the bck let ptr = arr.as_ref() as *const dyn Array; - // safety: we preallocated + // SAFETY: we preallocated owned.push_unchecked(arr); - // safety: the pointer is still valid as `owned` will not reallocate + // SAFETY: the pointer is still valid as `owned` will not reallocate builder.push(&*ptr as &dyn Array); } }, diff --git a/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs b/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs index 8f91a8d2ab4f..65706ff40b52 100644 --- a/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs +++ b/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs @@ -1,7 +1,7 @@ use arrow::bitmap::MutableBitmap; use arrow::compute::cast::utf8view_to_utf8; #[cfg(feature = "dtype-array")] -use arrow::legacy::compute::take::take_unchecked; +use arrow::compute::take::take_unchecked; use polars_utils::vec::PushUnchecked; use super::*; @@ -34,11 +34,11 @@ impl ChunkExplode for ListChunked { if !offsets.is_empty() { let start = offsets[0] as usize; let len = offsets[offsets.len() - 1] as usize - start; - // safety: + // SAFETY: // we are in bounds values = unsafe { values.sliced_unchecked(start, len) }; } - // safety: inner_dtype should be correct + // SAFETY: inner_dtype should be correct unsafe { Series::from_chunks_and_dtype_unchecked( self.name(), @@ -64,7 +64,7 @@ impl ChunkExplode for ListChunked { } } - // safety: inner_dtype should be correct + // SAFETY: inner_dtype should be correct let values = unsafe { Series::from_chunks_and_dtype_unchecked( self.name(), @@ -96,7 +96,7 @@ impl ChunkExplode for ArrayChunked { i * width }) .collect::>(); - // safety: monotonically increasing + // SAFETY: monotonically increasing let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; return Ok(offsets); @@ -114,14 +114,14 @@ impl ChunkExplode for ArrayChunked { if i == 0 { return current_offset; } - // Safety: we are within bounds + // SAFETY: we are within bounds if unsafe { validity.get_bit_unchecked(i - 1) } { current_offset += width as i64 } current_offset }) .collect::>(); - // safety: monotonically increasing + // SAFETY: monotonically increasing let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; Ok(offsets) } @@ -141,7 +141,7 @@ impl ChunkExplode for ArrayChunked { i * width }) .collect::>(); - // safety: monotonically increasing + // SAFETY: monotonically increasing let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; return Ok((s, offsets)); } @@ -158,7 +158,7 @@ impl ChunkExplode for ArrayChunked { let mut current_offset = 0i64; offsets.push(current_offset); (0..arr.len()).for_each(|i| { - // Safety: we are within bounds + // SAFETY: we are within bounds if unsafe { validity.get_bit_unchecked(i) } { let start = (i * width) as IdxSize; let end = start + width as IdxSize; @@ -170,13 +170,13 @@ impl ChunkExplode for ArrayChunked { offsets.push(current_offset); }); - // Safety: the indices we generate are in bounds + // SAFETY: the indices we generate are in bounds let chunk = unsafe { take_unchecked(&**values, &indices.into()) }; - // safety: monotonically increasing + // SAFETY: monotonically increasing let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; Ok(( - // Safety: inner_dtype should be correct + // SAFETY: inner_dtype should be correct unsafe { Series::from_chunks_and_dtype_unchecked(ca.name(), vec![chunk], &ca.inner_dtype()) }, @@ -229,7 +229,7 @@ impl ChunkExplode for StringChunked { let mut bitmap = MutableBitmap::with_capacity(capacity); let values = values.as_slice(); for (&offset, valid) in old_offsets[1..].iter().zip(validity) { - // safety: + // SAFETY: // new_offsets already has a single value, so -1 is always in bounds let latest_offset = unsafe { *new_offsets.get_unchecked(new_offsets.len() - 1) }; @@ -240,7 +240,7 @@ impl ChunkExplode for StringChunked { // take the string value and find the char offsets // create a new offset value for each char boundary - // safety: + // SAFETY: // we know we have string data. let str_val = unsafe { std::str::from_utf8_unchecked(val) }; @@ -279,7 +279,7 @@ impl ChunkExplode for StringChunked { let values = values.as_slice(); for &offset in &old_offsets[1..] { - // safety: + // SAFETY: // new_offsets already has a single value, so -1 is always in bounds let latest_offset = unsafe { *new_offsets.get_unchecked(new_offsets.len() - 1) }; debug_assert!(old_offset as usize <= values.len()); @@ -288,7 +288,7 @@ impl ChunkExplode for StringChunked { // take the string value and find the char offsets // create a new offset value for each char boundary - // safety: + // SAFETY: // we know we have string data. let str_val = unsafe { std::str::from_utf8_unchecked(val) }; diff --git a/crates/polars-core/src/chunked_array/ops/fill_null.rs b/crates/polars-core/src/chunked_array/ops/fill_null.rs index 9458021cf92d..e52cef3c296e 100644 --- a/crates/polars-core/src/chunked_array/ops/fill_null.rs +++ b/crates/polars-core/src/chunked_array/ops/fill_null.rs @@ -1,6 +1,6 @@ use arrow::legacy::kernels::set::set_at_nulls; use arrow::legacy::trusted_len::FromIteratorReversed; -use arrow::legacy::utils::{CustomIterTools, FromTrustedLenIterator}; +use arrow::legacy::utils::FromTrustedLenIterator; use num_traits::{Bounded, NumCast, One, Zero}; use crate::prelude::*; diff --git a/crates/polars-core/src/chunked_array/ops/filter.rs b/crates/polars-core/src/chunked_array/ops/filter.rs index b07b9703b388..1ab76a60077f 100644 --- a/crates/polars-core/src/chunked_array/ops/filter.rs +++ b/crates/polars-core/src/chunked_array/ops/filter.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "object")] -use arrow::array::Array; use polars_compute::filter::filter as filter_fn; #[cfg(feature = "object")] diff --git a/crates/polars-core/src/chunked_array/ops/full.rs b/crates/polars-core/src/chunked_array/ops/full.rs index de1c9d734dc2..16ba9d5f0ba8 100644 --- a/crates/polars-core/src/chunked_array/ops/full.rs +++ b/crates/polars-core/src/chunked_array/ops/full.rs @@ -1,5 +1,4 @@ use arrow::bitmap::MutableBitmap; -use arrow::legacy::array::default_arrays::FromData; use crate::chunked_array::builder::get_list_builder; use crate::prelude::*; @@ -180,7 +179,7 @@ impl ListChunked { #[cfg(feature = "dtype-struct")] impl ChunkFullNull for StructChunked { fn full_null(name: &str, length: usize) -> StructChunked { - let s = vec![Series::full_null("", length, &DataType::Null)]; + let s = vec![Series::new_null("", length)]; StructChunked::new_unchecked(name, &s) } } diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs index d017f4c45abe..53c0ee5c3546 100644 --- a/crates/polars-core/src/chunked_array/ops/gather.rs +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -1,14 +1,11 @@ -use arrow::array::Array; use arrow::bitmap::bitmask::BitMask; -use arrow::legacy::compute::take::take_unchecked; -use polars_error::{polars_bail, polars_ensure, PolarsResult}; +use arrow::compute::take::take_unchecked; +use polars_error::{polars_bail, polars_ensure}; use polars_utils::index::check_bounds; use crate::chunked_array::collect::prepare_collect_dtype; -use crate::chunked_array::ops::{ChunkTake, ChunkTakeUnchecked}; -use crate::chunked_array::ChunkedArray; -use crate::datatypes::{IdxCa, PolarsDataType, StaticArray}; use crate::prelude::*; +use crate::series::IsSorted; const BINARY_SEARCH_LIMIT: usize = 8; @@ -187,6 +184,18 @@ impl NotSpecialized for DecimalType {} #[cfg(feature = "object")] impl NotSpecialized for ObjectType {} +pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted { + use crate::series::IsSorted::*; + match (sorted_arr, sorted_idx) { + (_, Not) => Not, + (Not, _) => Not, + (Ascending, Ascending) => Ascending, + (Ascending, Descending) => Descending, + (Descending, Ascending) => Descending, + (Descending, Descending) => Ascending, + } +} + impl ChunkTakeUnchecked for ChunkedArray { /// Gather values from ChunkedArray by index. unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { @@ -233,16 +242,8 @@ impl ChunkTakeUnchecked for ChunkedAr }); let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks); + let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag()); - use crate::series::IsSorted::*; - let sorted_flag = match (ca.is_sorted_flag(), indices.is_sorted_flag()) { - (_, Not) => Not, - (Not, _) => Not, - (Ascending, Ascending) => Ascending, - (Ascending, Descending) => Descending, - (Descending, Ascending) => Descending, - (Descending, Descending) => Ascending, - }; out.set_sorted_flag(sorted_flag); out } @@ -262,15 +263,8 @@ impl ChunkTakeUnchecked for BinaryChunked { let mut out = ChunkedArray::from_chunks(self.name(), chunks); - use crate::series::IsSorted::*; - let sorted_flag = match (self.is_sorted_flag(), indices.is_sorted_flag()) { - (_, Not) => Not, - (Not, _) => Not, - (Ascending, Ascending) => Ascending, - (Ascending, Descending) => Descending, - (Descending, Ascending) => Descending, - (Descending, Descending) => Ascending, - }; + let sorted_flag = + _update_gather_sorted_flag(self.is_sorted_flag(), indices.is_sorted_flag()); out.set_sorted_flag(sorted_flag); out } diff --git a/crates/polars-core/src/chunked_array/ops/min_max_binary.rs b/crates/polars-core/src/chunked_array/ops/min_max_binary.rs index 279a4ae0719f..bc33f088b1f9 100644 --- a/crates/polars-core/src/chunked_array/ops/min_max_binary.rs +++ b/crates/polars-core/src/chunked_array/ops/min_max_binary.rs @@ -1,4 +1,3 @@ -use crate::datatypes::PolarsNumericType; use crate::prelude::*; use crate::series::arithmetic::coerce_lhs_rhs; diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 113c28e9b3dd..6221cf0fad7d 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -1,9 +1,6 @@ //! Traits for miscellaneous operations on ChunkedArray -use arrow::legacy::prelude::QuantileInterpolOptions; use arrow::offset::OffsetsBuffer; -#[cfg(feature = "object")] -use crate::datatypes::ObjectType; use crate::prelude::*; pub(crate) mod aggregate; @@ -35,7 +32,6 @@ pub(crate) mod rolling_window; mod set; mod shift; pub mod sort; -pub(crate) mod take; mod tile; #[cfg(feature = "algorithm_group_by")] pub(crate) mod unique; @@ -263,6 +259,10 @@ pub trait ChunkAgg { None } + fn min_max(&self) -> Option<(T, T)> { + Some((self.min()?, self.max()?)) + } + /// Returns the mean value in the array. /// Returns `None` if the array is empty or only contains null values. fn mean(&self) -> Option { diff --git a/crates/polars-core/src/chunked_array/ops/reverse.rs b/crates/polars-core/src/chunked_array/ops/reverse.rs index 085526476042..62d742fe6284 100644 --- a/crates/polars-core/src/chunked_array/ops/reverse.rs +++ b/crates/polars-core/src/chunked_array/ops/reverse.rs @@ -2,7 +2,7 @@ use crate::chunked_array::builder::get_fixed_size_list_builder; use crate::prelude::*; use crate::series::IsSorted; -use crate::utils::{CustomIterTools, NoNull}; +use crate::utils::NoNull; impl ChunkReverse for ChunkedArray where @@ -92,7 +92,7 @@ impl ChunkReverse for ArrayChunked { get_fixed_size_list_builder(&ca.inner_dtype(), ca.len(), ca.width(), ca.name()) .expect("not yet supported"); - // safety, we are within bounds + // SAFETY, we are within bounds unsafe { if arr.null_count() == 0 { for i in (0..arr.len()).rev() { diff --git a/crates/polars-core/src/chunked_array/ops/rolling_window.rs b/crates/polars-core/src/chunked_array/ops/rolling_window.rs index dff2e76e2616..2345f83df00d 100644 --- a/crates/polars-core/src/chunked_array/ops/rolling_window.rs +++ b/crates/polars-core/src/chunked_array/ops/rolling_window.rs @@ -30,7 +30,6 @@ impl Default for RollingOptionsFixedWindow { mod inner_mod { use std::ops::SubAssign; - use arrow::array::{Array, PrimitiveArray}; use arrow::bitmap::MutableBitmap; use arrow::legacy::bit_util::unset_bit_raw; use arrow::legacy::trusted_len::TrustedLenPush; @@ -107,7 +106,7 @@ mod inner_mod { if size < options.min_periods { builder.append_null(); } else { - // safety: + // SAFETY: // we are in bounds let arr_window = unsafe { arr.slice_typed_unchecked(start, size) }; @@ -117,7 +116,7 @@ mod inner_mod { continue; } - // Safety. + // SAFETY. // ptr is not dropped as we are in scope // We are also the only owner of the contents of the Arc // we do this to reduce heap allocs. @@ -161,7 +160,7 @@ mod inner_mod { if size < options.min_periods { builder.append_null(); } else { - // safety: + // SAFETY: // we are in bounds let arr_window = unsafe { arr.slice_typed_unchecked(start, size) }; @@ -171,7 +170,7 @@ mod inner_mod { continue; } - // Safety. + // SAFETY. // ptr is not dropped as we are in scope // We are also the only owner of the contents of the Arc // we do this to reduce heap allocs. diff --git a/crates/polars-core/src/chunked_array/ops/set.rs b/crates/polars-core/src/chunked_array/ops/set.rs index 0c9cdbd0f4aa..52646925a05c 100644 --- a/crates/polars-core/src/chunked_array/ops/set.rs +++ b/crates/polars-core/src/chunked_array/ops/set.rs @@ -1,9 +1,8 @@ use arrow::bitmap::MutableBitmap; use arrow::legacy::kernels::set::{scatter_single_non_null, set_with_mask}; -use arrow::legacy::prelude::FromData; use crate::prelude::*; -use crate::utils::{align_chunks_binary, CustomIterTools}; +use crate::utils::align_chunks_binary; macro_rules! impl_scatter_with { ($self:ident, $builder:ident, $idx:ident, $f:ident) => {{ diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index 20f285169343..c1e2fe379155 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -4,7 +4,6 @@ use polars_utils::iter::EnumerateIdxTrait; use super::*; #[cfg(feature = "dtype-struct")] use crate::utils::_split_offsets; -use crate::POOL; pub(crate) fn args_validate( ca: &ChunkedArray, @@ -89,17 +88,15 @@ pub(crate) fn encode_rows_vertical(by: &[Series]) -> PolarsResult> = splits - .into_par_iter() - .map(|(offset, len)| { - let sliced = by - .iter() - .map(|s| s.slice(offset as i64, len)) - .collect::>(); - let rows = _get_rows_encoded(&sliced, &descending, false)?; - Ok(rows.into_array()) - }) - .collect(); + let chunks = splits.into_par_iter().map(|(offset, len)| { + let sliced = by + .iter() + .map(|s| s.slice(offset as i64, len)) + .collect::>(); + let rows = _get_rows_encoded(&sliced, &descending, false)?; + Ok(rows.into_array()) + }); + let chunks = POOL.install(|| chunks.collect::>>()); Ok(BinaryOffsetChunked::from_chunk_iter("", chunks?)) } diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs index 6359017443f1..28c5c2962616 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -26,7 +26,7 @@ impl CategoricalChunked { .map(|(idx, _v)| idx) .collect_ca_trusted(self.name()); - // safety: + // SAFETY: // we only reordered the indexes so we are still in bounds return unsafe { CategoricalChunked::from_cats_and_rev_map_unchecked( @@ -38,7 +38,7 @@ impl CategoricalChunked { }; } let cats = self.physical().sort_with(options); - // safety: + // SAFETY: // we only reordered the indexes so we are still in bounds unsafe { CategoricalChunked::from_cats_and_rev_map_unchecked( diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index 3180e616e1a2..7a0684295111 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -7,10 +7,8 @@ mod categorical; use std::cmp::Ordering; pub(crate) use arg_sort_multiple::argsort_multiple_row_fmt; -use arrow::array::ValueSize; use arrow::bitmap::MutableBitmap; use arrow::buffer::Buffer; -use arrow::legacy::prelude::FromData; use arrow::legacy::trusted_len::TrustedLenPush; use rayon::prelude::*; pub use slice::*; @@ -21,7 +19,7 @@ use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca; use crate::prelude::sort::arg_sort_multiple::{arg_sort_multiple_impl, args_validate}; use crate::prelude::*; use crate::series::IsSorted; -use crate::utils::{CustomIterTools, NoNull}; +use crate::utils::NoNull; use crate::POOL; pub(crate) fn sort_by_branch(slice: &mut [T], descending: bool, cmp: C, parallel: bool) @@ -275,7 +273,7 @@ fn ordering_other_columns<'a>( idx_b: usize, ) -> Ordering { for (cmp, descending) in compare_inner.iter().zip(descending) { - // Safety: + // SAFETY: // indices are in bounds let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b) }; match (ordering, descending) { @@ -619,6 +617,7 @@ pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult #[cfg(feature = "dtype-categorical")] Categorical(_, _) | Enum(_, _) => s.rechunk(), Binary | Boolean => s.clone(), + BinaryOffset => s.clone(), String => s.cast(&Binary).unwrap(), #[cfg(feature = "dtype-struct")] Struct(_) => { diff --git a/crates/polars-core/src/chunked_array/ops/take/mod.rs b/crates/polars-core/src/chunked_array/ops/take/mod.rs deleted file mode 100644 index ccb11d118ba3..000000000000 --- a/crates/polars-core/src/chunked_array/ops/take/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -//! Traits to provide fast Random access to ChunkedArrays data. -//! This prevents downcasting every iteration. - -use crate::prelude::*; -use crate::utils::NoNull; - -mod take_chunked; -#[cfg(feature = "chunked_ids")] -pub(crate) use take_chunked::*; diff --git a/crates/polars-core/src/chunked_array/ops/take/take_chunked.rs b/crates/polars-core/src/chunked_array/ops/take/take_chunked.rs deleted file mode 100644 index e977b7b5f1c0..000000000000 --- a/crates/polars-core/src/chunked_array/ops/take/take_chunked.rs +++ /dev/null @@ -1,273 +0,0 @@ -use polars_utils::slice::GetSaferUnchecked; - -use super::*; -use crate::series::IsSorted; - -pub trait TakeChunked { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self; - - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self; -} - -impl TakeChunked for ChunkedArray -where - T: PolarsNumericType, -{ - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { - let mut ca = if self.null_count() == 0 { - let arrs = self - .downcast_iter() - .map(|arr| arr.values().as_slice()) - .collect::>(); - - let ca: NoNull = by - .iter() - .map(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked_release(*chunk_idx as usize); - *arr.get_unchecked_release(*array_idx as usize) - }) - .collect_trusted(); - - ca.into_inner() - } else { - let arrs = self.downcast_iter().collect::>(); - by.iter() - .map(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(*chunk_idx as usize); - arr.get_unchecked(*array_idx as usize) - }) - .collect_trusted() - }; - ca.rename(self.name()); - ca.set_sorted_flag(sorted); - ca - } - - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { - let arrs = self.downcast_iter().collect::>(); - let mut ca: Self = by - .iter() - .map(|opt_idx| { - opt_idx.and_then(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked_release(chunk_idx as usize); - arr.get_unchecked(array_idx as usize) - }) - }) - .collect_trusted(); - - ca.rename(self.name()); - ca - } -} - -impl TakeChunked for StringChunked { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { - self.as_binary() - .take_chunked_unchecked(by, sorted) - .to_string() - } - - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { - self.as_binary().take_opt_chunked_unchecked(by).to_string() - } -} - -impl TakeChunked for BinaryOffsetChunked { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { - let arrs = self.downcast_iter().collect::>(); - let mut ca: Self = by - .iter() - .map(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(*chunk_idx as usize); - arr.get_unchecked(*array_idx as usize) - }) - .collect_trusted(); - ca.rename(self.name()); - ca.set_sorted_flag(sorted); - ca - } - - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { - let arrs = self.downcast_iter().collect::>(); - let mut ca: Self = by - .iter() - .map(|opt_idx| { - opt_idx.and_then(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(chunk_idx as usize); - arr.get_unchecked(array_idx as usize) - }) - }) - .collect_trusted(); - - ca.rename(self.name()); - ca - } -} - -impl TakeChunked for BinaryChunked { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { - let arrs = self.downcast_iter().collect::>(); - let mut ca: Self = by - .iter() - .map(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(*chunk_idx as usize); - arr.get_unchecked(*array_idx as usize) - }) - .collect_trusted(); - ca.rename(self.name()); - ca.set_sorted_flag(sorted); - ca - } - - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { - let arrs = self.downcast_iter().collect::>(); - let mut ca: Self = by - .iter() - .map(|opt_idx| { - opt_idx.and_then(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(chunk_idx as usize); - arr.get_unchecked(array_idx as usize) - }) - }) - .collect_trusted(); - - ca.rename(self.name()); - ca - } -} - -impl TakeChunked for BooleanChunked { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { - let arrs = self.downcast_iter().collect::>(); - let mut ca: Self = by - .iter() - .map(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(*chunk_idx as usize); - arr.get_unchecked(*array_idx as usize) - }) - .collect_trusted(); - ca.rename(self.name()); - ca.set_sorted_flag(sorted); - ca - } - - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { - let arrs = self.downcast_iter().collect::>(); - let mut ca: Self = by - .iter() - .map(|opt_idx| { - opt_idx.and_then(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(chunk_idx as usize); - arr.get_unchecked(array_idx as usize) - }) - }) - .collect_trusted(); - - ca.rename(self.name()); - ca - } -} - -impl TakeChunked for ListChunked { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { - let arrs = self.downcast_iter().collect::>(); - let mut ca: Self = by - .iter() - .map(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(*chunk_idx as usize); - arr.get_unchecked(*array_idx as usize) - }) - .collect(); - ca.rename(self.name()); - ca.set_sorted_flag(sorted); - ca - } - - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { - let arrs = self.downcast_iter().collect::>(); - let mut ca: Self = by - .iter() - .map(|opt_idx| { - opt_idx.and_then(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(chunk_idx as usize); - arr.get_unchecked(array_idx as usize) - }) - }) - .collect(); - - ca.rename(self.name()); - ca - } -} - -#[cfg(feature = "dtype-array")] -impl TakeChunked for ArrayChunked { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { - let arrs = self.downcast_iter().collect::>(); - let iter = by.iter().map(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(*chunk_idx as usize); - arr.get_unchecked(*array_idx as usize) - }); - let mut ca = Self::from_iter_and_args( - iter, - self.width(), - by.len(), - Some(self.inner_dtype()), - self.name(), - ); - ca.set_sorted_flag(sorted); - ca - } - - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { - let arrs = self.downcast_iter().collect::>(); - let iter = by.iter().map(|opt_idx| { - opt_idx.and_then(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(chunk_idx as usize); - arr.get_unchecked(array_idx as usize) - }) - }); - - Self::from_iter_and_args( - iter, - self.width(), - by.len(), - Some(self.inner_dtype()), - self.name(), - ) - } -} -#[cfg(feature = "object")] -impl TakeChunked for ObjectChunked { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { - let arrs = self.downcast_iter().collect::>(); - - let mut ca: Self = by - .iter() - .map(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(*chunk_idx as usize); - arr.get_unchecked(*array_idx as usize).cloned() - }) - .collect(); - - ca.rename(self.name()); - ca.set_sorted_flag(sorted); - ca - } - - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { - let arrs = self.downcast_iter().collect::>(); - let mut ca: Self = by - .iter() - .map(|opt_idx| { - opt_idx.and_then(|[chunk_idx, array_idx]| { - let arr = arrs.get_unchecked(chunk_idx as usize); - arr.get_unchecked(array_idx as usize).cloned() - }) - }) - .collect(); - - ca.rename(self.name()); - ca - } -} diff --git a/crates/polars-core/src/chunked_array/ops/unique/mod.rs b/crates/polars-core/src/chunked_array/ops/unique/mod.rs index 27b10a659dd6..382047578682 100644 --- a/crates/polars-core/src/chunked_array/ops/unique/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/unique/mod.rs @@ -1,11 +1,8 @@ use std::hash::Hash; use arrow::bitmap::MutableBitmap; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; -#[cfg(feature = "object")] -use crate::datatypes::ObjectType; -use crate::datatypes::PlHashSet; -use crate::frame::group_by::GroupsProxy; use crate::hashing::_HASHMAP_INIT_SIZE; use crate::prelude::*; use crate::series::IsSorted; @@ -60,12 +57,13 @@ impl ChunkUnique> for ObjectChunked { fn arg_unique(a: impl Iterator, capacity: usize) -> Vec where - T: Hash + Eq, + T: ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let mut set = PlHashSet::new(); let mut unique = Vec::with_capacity(capacity); a.enumerate().for_each(|(idx, val)| { - if set.insert(val) { + if set.insert(val.to_total_ord()) { unique.push(idx as IdxSize) } }); @@ -76,15 +74,16 @@ macro_rules! arg_unique_ca { ($ca:expr) => {{ match $ca.has_validity() { false => arg_unique($ca.into_no_null_iter(), $ca.len()), - _ => arg_unique($ca.into_iter(), $ca.len()), + _ => arg_unique($ca.iter(), $ca.len()), } }}; } impl ChunkUnique for ChunkedArray where - T: PolarsIntegerType, - T::Native: Hash + Eq + Ord, + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + Ord, ChunkedArray: IntoSeries + for<'a> ChunkCompare<&'a ChunkedArray, Item = BooleanChunked>, { fn unique(&self) -> PolarsResult { @@ -96,25 +95,23 @@ where IsSorted::Ascending | IsSorted::Descending => { if self.null_count() > 0 { let mut arr = MutablePrimitiveArray::with_capacity(self.len()); - let mut iter = self.into_iter(); - let mut last = None; - if let Some(val) = iter.next() { - last = val; - arr.push(val) - }; + if !self.is_empty() { + let mut iter = self.iter(); + let last = iter.next().unwrap(); + arr.push(last); + let mut last = last.to_total_ord(); - #[allow(clippy::unnecessary_filter_map)] - let to_extend = iter.filter_map(|opt_val| { - if opt_val != last { - last = opt_val; - Some(opt_val) - } else { - None - } - }); + let to_extend = iter.filter(|opt_val| { + let opt_val_tot_ord = opt_val.to_total_ord(); + let out = opt_val_tot_ord != last; + last = opt_val_tot_ord; + out + }); + + arr.extend(to_extend); + } - arr.extend(to_extend); let arr: PrimitiveArray = arr.into(); Ok(ChunkedArray::with_chunk(self.name(), arr)) } else { @@ -142,15 +139,18 @@ where IsSorted::Ascending | IsSorted::Descending => { if self.null_count() > 0 { let mut count = 0; - let mut iter = self.into_iter(); - let mut last = None; - if let Some(val) = iter.next() { - last = val; - count += 1; - }; + if self.is_empty() { + return Ok(count); + } + + let mut iter = self.iter(); + let mut last = iter.next().unwrap().to_total_ord(); + + count += 1; iter.for_each(|opt_val| { + let opt_val = opt_val.to_total_ord(); if opt_val != last { last = opt_val; count += 1; @@ -254,30 +254,6 @@ impl ChunkUnique for BooleanChunked { } } -impl ChunkUnique for Float32Chunked { - fn unique(&self) -> PolarsResult> { - let ca = self.bit_repr_small(); - let ca = ca.unique()?; - Ok(ca._reinterpret_float()) - } - - fn arg_unique(&self) -> PolarsResult { - self.bit_repr_small().arg_unique() - } -} - -impl ChunkUnique for Float64Chunked { - fn unique(&self) -> PolarsResult> { - let ca = self.bit_repr_large(); - let ca = ca.unique()?; - Ok(ca._reinterpret_float()) - } - - fn arg_unique(&self) -> PolarsResult { - self.bit_repr_large().arg_unique() - } -} - #[cfg(test)] mod test { use crate::prelude::*; diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs index 8033b4d80f2b..80b3bcdfd815 100644 --- a/crates/polars-core/src/chunked_array/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -1,8 +1,7 @@ use arrow::compute::if_then_else::if_then_else; -use arrow::legacy::array::default_arrays::FromData; use crate::prelude::*; -use crate::utils::{align_chunks_ternary, CustomIterTools}; +use crate::utils::align_chunks_ternary; fn ternary_apply(predicate: bool, truthy: T, falsy: T) -> T { if predicate { diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 7476183eab09..18b1117669fc 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -3,12 +3,12 @@ use polars_error::to_compute_err; use rand::distributions::Bernoulli; use rand::prelude::*; use rand::seq::index::IndexVec; -use rand_distr::{Distribution, Normal, Standard, StandardNormal, Uniform}; +use rand_distr::{Normal, Standard, StandardNormal, Uniform}; use crate::prelude::DataType::Float64; use crate::prelude::*; use crate::random::get_global_random_u64; -use crate::utils::{CustomIterTools, NoNull}; +use crate::utils::NoNull; fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option) -> IdxCa { if len == 0 { @@ -194,7 +194,7 @@ impl DataFrame { Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed), None => { let new_cols = self.columns.iter().map(Series::clear).collect_trusted(); - Ok(DataFrame::new_no_checks(new_cols)) + Ok(unsafe { DataFrame::new_no_checks(new_cols) }) }, } } @@ -239,7 +239,7 @@ impl DataFrame { }, None => { let new_cols = self.columns.iter().map(Series::clear).collect_trusted(); - Ok(DataFrame::new_no_checks(new_cols)) + Ok(unsafe { DataFrame::new_no_checks(new_cols) }) }, } } diff --git a/crates/polars-core/src/chunked_array/temporal/date.rs b/crates/polars-core/src/chunked_array/temporal/date.rs index 7f6146fa921b..f517f5acc8ef 100644 --- a/crates/polars-core/src/chunked_array/temporal/date.rs +++ b/crates/polars-core/src/chunked_array/temporal/date.rs @@ -13,7 +13,7 @@ pub(crate) fn naive_date_to_date(nd: NaiveDate) -> i32 { impl DateChunked { pub fn as_date_iter(&self) -> impl TrustedLen> + '_ { - // safety: we know the iterators len + // SAFETY: we know the iterators len unsafe { self.downcast_iter() .flat_map(|iter| { diff --git a/crates/polars-core/src/chunked_array/temporal/datetime.rs b/crates/polars-core/src/chunked_array/temporal/datetime.rs index 3fcea28a66ce..bd3e6fae1c47 100644 --- a/crates/polars-core/src/chunked_array/temporal/datetime.rs +++ b/crates/polars-core/src/chunked_array/temporal/datetime.rs @@ -4,16 +4,10 @@ use arrow::temporal_conversions::{ timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, }; use chrono::format::{DelayedFormat, StrftimeItems}; -use chrono::NaiveDate; #[cfg(feature = "timezones")] use chrono::TimeZone as TimeZoneTrait; -#[cfg(feature = "timezones")] -use chrono_tz::Tz; -use super::conversion::{datetime_to_timestamp_ms, datetime_to_timestamp_ns}; use super::*; -#[cfg(feature = "timezones")] -use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; use crate::prelude::*; @@ -176,12 +170,12 @@ impl DatetimeChunked { use TimeUnit::*; match (current_unit, tu) { (Nanoseconds, Microseconds) => { - let ca = &self.0 / 1_000; + let ca = (&self.0).wrapping_trunc_div_scalar(1_000); out.0 = ca; out }, (Nanoseconds, Milliseconds) => { - let ca = &self.0 / 1_000_000; + let ca = (&self.0).wrapping_trunc_div_scalar(1_000_000); out.0 = ca; out }, @@ -191,7 +185,7 @@ impl DatetimeChunked { out }, (Microseconds, Milliseconds) => { - let ca = &self.0 / 1_000; + let ca = (&self.0).wrapping_trunc_div_scalar(1_000); out.0 = ca; out }, @@ -223,17 +217,6 @@ impl DatetimeChunked { self.2 = Some(Datetime(self.time_unit(), Some(time_zone))); Ok(()) } - #[cfg(feature = "timezones")] - pub fn convert_time_zone(mut self, time_zone: TimeZone) -> PolarsResult { - polars_ensure!( - self.time_zone().is_some(), - InvalidOperation: - "cannot call `convert_time_zone` on tz-naive; \ - set a time zone first with `replace_time_zone`" - ); - self.set_time_zone(time_zone)?; - Ok(self) - } } #[cfg(test)] diff --git a/crates/polars-core/src/chunked_array/temporal/duration.rs b/crates/polars-core/src/chunked_array/temporal/duration.rs index 7258cb83326e..7c649e3178b0 100644 --- a/crates/polars-core/src/chunked_array/temporal/duration.rs +++ b/crates/polars-core/src/chunked_array/temporal/duration.rs @@ -20,12 +20,12 @@ impl DurationChunked { use TimeUnit::*; match (current_unit, tu) { (Nanoseconds, Microseconds) => { - let ca = &self.0 / 1_000; + let ca = (&self.0).wrapping_trunc_div_scalar(1_000); out.0 = ca; out }, (Nanoseconds, Milliseconds) => { - let ca = &self.0 / 1_000_000; + let ca = (&self.0).wrapping_trunc_div_scalar(1_000_000); out.0 = ca; out }, @@ -35,7 +35,7 @@ impl DurationChunked { out }, (Microseconds, Milliseconds) => { - let ca = &self.0 / 1_000; + let ca = (&self.0).wrapping_trunc_div_scalar(1_000); out.0 = ca; out }, diff --git a/crates/polars-core/src/chunked_array/trusted_len.rs b/crates/polars-core/src/chunked_array/trusted_len.rs index a241e0432569..baa473cc07e1 100644 --- a/crates/polars-core/src/chunked_array/trusted_len.rs +++ b/crates/polars-core/src/chunked_array/trusted_len.rs @@ -4,7 +4,7 @@ use arrow::legacy::trusted_len::{FromIteratorReversed, TrustedLenPush}; use crate::chunked_array::upstream_traits::PolarsAsRef; use crate::prelude::*; -use crate::utils::{CustomIterTools, FromTrustedLenIterator, NoNull}; +use crate::utils::{FromTrustedLenIterator, NoNull}; impl FromTrustedLenIterator> for ChunkedArray where @@ -193,7 +193,6 @@ impl FromTrustedLenIterator> for ObjectChunked { #[cfg(test)] mod test { use super::*; - use crate::utils::CustomIterTools; #[test] fn test_reverse_collect() { diff --git a/crates/polars-core/src/chunked_array/upstream_traits.rs b/crates/polars-core/src/chunked_array/upstream_traits.rs index 69fc84b847a0..ce0fbcf4ad7b 100644 --- a/crates/polars-core/src/chunked_array/upstream_traits.rs +++ b/crates/polars-core/src/chunked_array/upstream_traits.rs @@ -1,14 +1,10 @@ //! Implementations of upstream traits for [`ChunkedArray`] use std::borrow::{Borrow, Cow}; use std::collections::LinkedList; -use std::iter::FromIterator; use std::marker::PhantomData; -use std::sync::Arc; -use arrow::array::{BooleanArray, PrimitiveArray}; use arrow::bitmap::{Bitmap, MutableBitmap}; use polars_utils::sync::SyncPtr; -use rayon::iter::{FromParallelIterator, IntoParallelIterator}; use rayon::prelude::*; use crate::chunked_array::builder::{ @@ -22,7 +18,7 @@ use crate::chunked_array::object::builder::get_object_type; use crate::chunked_array::object::ObjectArray; use crate::prelude::*; use crate::utils::flatten::flatten_par; -use crate::utils::{get_iter_capacity, CustomIterTools, NoNull}; +use crate::utils::{get_iter_capacity, NoNull}; impl Default for ChunkedArray { fn default() -> Self { @@ -278,7 +274,10 @@ impl FromIterator>> for ListChunked { #[cfg(feature = "dtype-array")] impl ArrayChunked { - pub(crate) unsafe fn from_iter_and_args>>>( + /// # Safety + /// The caller must ensure that the underlying `Arrays` match the given datatype. + /// That means the logical map should map to the physical type. + pub unsafe fn from_iter_and_args>>( iter: I, width: usize, capacity: usize, @@ -609,42 +608,35 @@ where } } -/// From trait -impl<'a> From<&'a StringChunked> for Vec> { - fn from(ca: &'a StringChunked) -> Self { - ca.into_iter().collect() +impl<'a, T> From<&'a ChunkedArray> for Vec>> +where + T: PolarsDataType, +{ + fn from(ca: &'a ChunkedArray) -> Self { + let mut out = Vec::with_capacity(ca.len()); + for arr in ca.downcast_iter() { + out.extend(arr.iter()) + } + out } } - impl From for Vec> { fn from(ca: StringChunked) -> Self { - ca.into_iter() - .map(|opt| opt.map(|s| s.to_string())) - .collect() - } -} - -impl<'a> From<&'a BooleanChunked> for Vec> { - fn from(ca: &'a BooleanChunked) -> Self { - ca.into_iter().collect() + ca.iter().map(|opt| opt.map(|s| s.to_string())).collect() } } impl From for Vec> { fn from(ca: BooleanChunked) -> Self { - ca.into_iter().collect() - } -} - -impl<'a, T> From<&'a ChunkedArray> for Vec> -where - T: PolarsNumericType, -{ - fn from(ca: &'a ChunkedArray) -> Self { - ca.into_iter().collect() + let mut out = Vec::with_capacity(ca.len()); + for arr in ca.downcast_iter() { + out.extend(arr.iter()) + } + out } } +/// From trait impl FromParallelIterator> for ListChunked { fn from_par_iter(iter: I) -> Self where diff --git a/crates/polars-core/src/datatypes/_serde.rs b/crates/polars-core/src/datatypes/_serde.rs index 69642e749796..922ba5b95b3f 100644 --- a/crates/polars-core/src/datatypes/_serde.rs +++ b/crates/polars-core/src/datatypes/_serde.rs @@ -4,8 +4,8 @@ //! We could use [serde_1712](https://github.com/serde-rs/serde/issues/1712), but that gave problems caused by //! [rust_96956](https://github.com/rust-lang/rust/issues/96956), so we make a dummy type without static -use serde::de::{SeqAccess, Visitor}; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::de::SeqAccess; +use serde::{Deserialize, Serialize}; use super::*; diff --git a/crates/polars-core/src/datatypes/aliases.rs b/crates/polars-core/src/datatypes/aliases.rs index d5ce2da0974b..42ecbd018bdf 100644 --- a/crates/polars-core/src/datatypes/aliases.rs +++ b/crates/polars-core/src/datatypes/aliases.rs @@ -4,9 +4,6 @@ pub use polars_utils::aliases::{InitHashMaps, PlHashMap, PlHashSet, PlIndexMap, use super::*; use crate::hashing::IdBuildHasher; -/// [ChunkIdx, DfIdx] -pub type ChunkId = [IdxSize; 2]; - #[cfg(not(feature = "bigidx"))] pub type IdxCa = UInt32Chunked; #[cfg(feature = "bigidx")] diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index 0c6ecc8058c1..523b7a9939d4 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -1,15 +1,12 @@ #[cfg(feature = "dtype-struct")] use arrow::legacy::trusted_len::TrustedLenPush; -#[cfg(feature = "dtype-date")] -use arrow::temporal_conversions::{ - timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, -}; use arrow::types::PrimitiveType; use polars_utils::format_smartstring; #[cfg(feature = "dtype-struct")] use polars_utils::slice::GetSaferUnchecked; #[cfg(feature = "dtype-categorical")] use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::ToTotalOrd; use polars_utils::unwrap::UnwrapUncheckedRelease; use super::*; @@ -200,7 +197,7 @@ impl<'a> Deserialize<'a> for AnyValue<'static> { ) })?; - // safety: + // SAFETY: // we are repr: u8 and check last value that we are in bounds let field = unsafe { if field <= LAST { @@ -342,10 +339,15 @@ impl<'a> Deserialize<'a> for AnyValue<'static> { } impl<'a> AnyValue<'a> { + /// Get the matching [`DataType`] for this [`AnyValue`]`. + /// + /// Note: For `Categorical` and `Enum` values, the exact mapping information + /// is not preserved in the result for performance reasons. pub fn dtype(&self) -> DataType { use AnyValue::*; - match self.as_borrowed() { - Null => DataType::Unknown, + match self { + Null => DataType::Null, + Boolean(_) => DataType::Boolean, Int8(_) => DataType::Int8, Int16(_) => DataType::Int16, Int32(_) => DataType::Int32, @@ -356,29 +358,36 @@ impl<'a> AnyValue<'a> { UInt64(_) => DataType::UInt64, Float32(_) => DataType::Float32, Float64(_) => DataType::Float64, + String(_) | StringOwned(_) => DataType::String, + Binary(_) | BinaryOwned(_) => DataType::Binary, #[cfg(feature = "dtype-date")] Date(_) => DataType::Date, - #[cfg(feature = "dtype-datetime")] - Datetime(_, tu, tz) => DataType::Datetime(tu, tz.clone()), #[cfg(feature = "dtype-time")] Time(_) => DataType::Time, + #[cfg(feature = "dtype-datetime")] + Datetime(_, tu, tz) => DataType::Datetime(*tu, (*tz).clone()), #[cfg(feature = "dtype-duration")] - Duration(_, tu) => DataType::Duration(tu), - Boolean(_) => DataType::Boolean, - String(_) => DataType::String, + Duration(_, tu) => DataType::Duration(*tu), #[cfg(feature = "dtype-categorical")] Categorical(_, _, _) => DataType::Categorical(None, Default::default()), #[cfg(feature = "dtype-categorical")] Enum(_, _, _) => DataType::Enum(None, Default::default()), List(s) => DataType::List(Box::new(s.dtype().clone())), + #[cfg(feature = "dtype-array")] + Array(s, size) => DataType::Array(Box::new(s.dtype().clone()), *size), #[cfg(feature = "dtype-struct")] Struct(_, _, fields) => DataType::Struct(fields.to_vec()), #[cfg(feature = "dtype-struct")] StructOwned(payload) => DataType::Struct(payload.1.clone()), - Binary(_) => DataType::Binary, - _ => unimplemented!(), + #[cfg(feature = "dtype-decimal")] + Decimal(_, scale) => DataType::Decimal(None, Some(*scale)), + #[cfg(feature = "object")] + Object(o) => DataType::Object(o.type_name(), None), + #[cfg(feature = "object")] + ObjectOwned(o) => DataType::Object(o.0.type_name(), None), } } + /// Extract a numerical value from the AnyValue #[doc(hidden)] #[inline] @@ -439,13 +448,17 @@ impl<'a> AnyValue<'a> { } pub fn is_numeric(&self) -> bool { - self.is_signed_integer() || self.is_unsigned_integer() || self.is_float() + self.is_integer() || self.is_float() } pub fn is_float(&self) -> bool { matches!(self, AnyValue::Float32(_) | AnyValue::Float64(_)) } + pub fn is_integer(&self) -> bool { + self.is_signed_integer() || self.is_unsigned_integer() + } + pub fn is_signed_integer(&self) -> bool { matches!( self, @@ -460,142 +473,148 @@ impl<'a> AnyValue<'a> { ) } - pub fn strict_cast(&self, dtype: &'a DataType) -> PolarsResult> { - fn cast_numeric<'a>(av: &AnyValue, dtype: &'a DataType) -> PolarsResult> { - Ok(match dtype { - DataType::UInt8 => AnyValue::UInt8(av.try_extract::()?), - DataType::UInt16 => AnyValue::UInt16(av.try_extract::()?), - DataType::UInt32 => AnyValue::UInt32(av.try_extract::()?), - DataType::UInt64 => AnyValue::UInt64(av.try_extract::()?), - DataType::Int8 => AnyValue::Int8(av.try_extract::()?), - DataType::Int16 => AnyValue::Int16(av.try_extract::()?), - DataType::Int32 => AnyValue::Int32(av.try_extract::()?), - DataType::Int64 => AnyValue::Int64(av.try_extract::()?), - DataType::Float32 => AnyValue::Float32(av.try_extract::()?), - DataType::Float64 => AnyValue::Float64(av.try_extract::()?), - _ => { - polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", av, dtype) - }, - }) - } + pub fn is_null(&self) -> bool { + matches!(self, AnyValue::Null) + } - fn cast_boolean<'a>(av: &AnyValue) -> PolarsResult> { - Ok(match av { - AnyValue::UInt8(v) => AnyValue::Boolean(*v != u8::default()), - AnyValue::UInt16(v) => AnyValue::Boolean(*v != u16::default()), - AnyValue::UInt32(v) => AnyValue::Boolean(*v != u32::default()), - AnyValue::UInt64(v) => AnyValue::Boolean(*v != u64::default()), - AnyValue::Int8(v) => AnyValue::Boolean(*v != i8::default()), - AnyValue::Int16(v) => AnyValue::Boolean(*v != i16::default()), - AnyValue::Int32(v) => AnyValue::Boolean(*v != i32::default()), - AnyValue::Int64(v) => AnyValue::Boolean(*v != i64::default()), - AnyValue::Float32(v) => AnyValue::Boolean(*v != f32::default()), - AnyValue::Float64(v) => AnyValue::Boolean(*v != f64::default()), - _ => { - polars_bail!(ComputeError: "cannot cast any-value {:?} to boolean", av) - }, - }) + pub fn is_nested_null(&self) -> bool { + match self { + AnyValue::Null => true, + AnyValue::List(s) => s.null_count() == s.len(), + #[cfg(feature = "dtype-struct")] + AnyValue::Struct(_, _, _) => self._iter_struct_av().all(|av| av.is_nested_null()), + _ => false, } + } - let new_av = match self { - _ if (self.is_boolean() - | self.is_signed_integer() - | self.is_unsigned_integer() - | self.is_float()) => - { - match dtype { - #[cfg(feature = "dtype-date")] - DataType::Date => AnyValue::Date(self.try_extract::()?), - #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) => { - AnyValue::Datetime(self.try_extract::()?, *tu, tz) - }, - #[cfg(feature = "dtype-duration")] - DataType::Duration(tu) => AnyValue::Duration(self.try_extract::()?, *tu), - #[cfg(feature = "dtype-time")] - DataType::Time => AnyValue::Time(self.try_extract::()?), - DataType::String => { - AnyValue::StringOwned(format_smartstring!("{}", self.try_extract::()?)) - }, - DataType::Boolean => return cast_boolean(self), - _ => return cast_numeric(self, dtype), - } + /// Cast `AnyValue` to the provided data type and return a new `AnyValue` with type `dtype`, + /// if possible. + /// + pub fn strict_cast(&self, dtype: &'a DataType) -> PolarsResult> { + let new_av = match (self, dtype) { + // to numeric + (av, DataType::UInt8) => AnyValue::UInt8(av.try_extract::()?), + (av, DataType::UInt16) => AnyValue::UInt16(av.try_extract::()?), + (av, DataType::UInt32) => AnyValue::UInt32(av.try_extract::()?), + (av, DataType::UInt64) => AnyValue::UInt64(av.try_extract::()?), + (av, DataType::Int8) => AnyValue::Int8(av.try_extract::()?), + (av, DataType::Int16) => AnyValue::Int16(av.try_extract::()?), + (av, DataType::Int32) => AnyValue::Int32(av.try_extract::()?), + (av, DataType::Int64) => AnyValue::Int64(av.try_extract::()?), + (av, DataType::Float32) => AnyValue::Float32(av.try_extract::()?), + (av, DataType::Float64) => AnyValue::Float64(av.try_extract::()?), + + // to boolean + (AnyValue::UInt8(v), DataType::Boolean) => AnyValue::Boolean(*v != u8::default()), + (AnyValue::UInt16(v), DataType::Boolean) => AnyValue::Boolean(*v != u16::default()), + (AnyValue::UInt32(v), DataType::Boolean) => AnyValue::Boolean(*v != u32::default()), + (AnyValue::UInt64(v), DataType::Boolean) => AnyValue::Boolean(*v != u64::default()), + (AnyValue::Int8(v), DataType::Boolean) => AnyValue::Boolean(*v != i8::default()), + (AnyValue::Int16(v), DataType::Boolean) => AnyValue::Boolean(*v != i16::default()), + (AnyValue::Int32(v), DataType::Boolean) => AnyValue::Boolean(*v != i32::default()), + (AnyValue::Int64(v), DataType::Boolean) => AnyValue::Boolean(*v != i64::default()), + (AnyValue::Float32(v), DataType::Boolean) => AnyValue::Boolean(*v != f32::default()), + (AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()), + + // to string + (av, DataType::String) => { + AnyValue::StringOwned(format_smartstring!("{}", av.try_extract::()?)) }, + + // to binary + (AnyValue::String(v), DataType::Binary) => AnyValue::Binary(v.as_bytes()), + + // to datetime #[cfg(feature = "dtype-datetime")] - AnyValue::Datetime(v, tu, None) => match dtype { - #[cfg(feature = "dtype-date")] - // Datetime to Date - DataType::Date => { - let convert = match tu { - TimeUnit::Nanoseconds => timestamp_ns_to_datetime, - TimeUnit::Microseconds => timestamp_us_to_datetime, - TimeUnit::Milliseconds => timestamp_ms_to_datetime, - }; - let ndt = convert(*v); - let date_value = naive_datetime_to_date(ndt); - AnyValue::Date(date_value) + (av, DataType::Datetime(tu, tz)) if av.is_numeric() => { + AnyValue::Datetime(av.try_extract::()?, *tu, tz) + }, + #[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))] + (AnyValue::Date(v), DataType::Datetime(tu, _)) => AnyValue::Datetime( + match tu { + TimeUnit::Nanoseconds => (*v as i64) * NS_IN_DAY, + TimeUnit::Microseconds => (*v as i64) * US_IN_DAY, + TimeUnit::Milliseconds => (*v as i64) * MS_IN_DAY, }, - #[cfg(feature = "dtype-time")] - // Datetime to Time - DataType::Time => { - let ns_since_midnight = match tu { - TimeUnit::Nanoseconds => *v % NS_IN_DAY, - TimeUnit::Microseconds => (*v % US_IN_DAY) * 1_000i64, - TimeUnit::Milliseconds => (*v % MS_IN_DAY) * 1_000_000i64, - }; - AnyValue::Time(ns_since_midnight) + *tu, + &None, + ), + #[cfg(feature = "dtype-datetime")] + (AnyValue::Datetime(v, tu, _), DataType::Datetime(tu_r, tz_r)) => AnyValue::Datetime( + match (tu, tu_r) { + (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64, + (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64, + (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64, + (TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64, + _ => *v, }, - _ => return cast_numeric(self, dtype), - }, + *tu_r, + tz_r, + ), + + // to date + #[cfg(feature = "dtype-date")] + (av, DataType::Date) if av.is_numeric() => AnyValue::Date(av.try_extract::()?), + #[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))] + (AnyValue::Datetime(v, tu, _), DataType::Date) => AnyValue::Date(match tu { + TimeUnit::Nanoseconds => *v / NS_IN_DAY, + TimeUnit::Microseconds => *v / US_IN_DAY, + TimeUnit::Milliseconds => *v / MS_IN_DAY, + } as i32), + + // to time + #[cfg(feature = "dtype-time")] + (av, DataType::Time) if av.is_numeric() => AnyValue::Time(av.try_extract::()?), + #[cfg(all(feature = "dtype-time", feature = "dtype-datetime"))] + (AnyValue::Datetime(v, tu, _), DataType::Time) => AnyValue::Time(match tu { + TimeUnit::Nanoseconds => *v % NS_IN_DAY, + TimeUnit::Microseconds => (*v % US_IN_DAY) * 1_000i64, + TimeUnit::Milliseconds => (*v % MS_IN_DAY) * 1_000_000i64, + }), + + // to duration #[cfg(feature = "dtype-duration")] - AnyValue::Duration(v, _) => match dtype { - DataType::Time | DataType::Date | DataType::Datetime(_, _) => { - polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", v, dtype) - }, - _ => return cast_numeric(self, dtype), + (av, DataType::Duration(tu)) if av.is_numeric() => { + AnyValue::Duration(av.try_extract::()?, *tu) }, - #[cfg(feature = "dtype-time")] - AnyValue::Time(v) => match dtype { - #[cfg(feature = "dtype-duration")] - // Time to Duration - DataType::Duration(tu) => { - let duration_value = match tu { - TimeUnit::Nanoseconds => *v, - TimeUnit::Microseconds => *v / 1_000i64, - TimeUnit::Milliseconds => *v / 1_000_000i64, - }; - AnyValue::Duration(duration_value, *tu) + #[cfg(all(feature = "dtype-duration", feature = "dtype-time"))] + (AnyValue::Time(v), DataType::Duration(tu)) => AnyValue::Duration( + match *tu { + TimeUnit::Nanoseconds => *v, + TimeUnit::Microseconds => *v / 1_000i64, + TimeUnit::Milliseconds => *v / 1_000_000i64, }, - _ => return cast_numeric(self, dtype), - }, - #[cfg(feature = "dtype-date")] - AnyValue::Date(v) => match dtype { - #[cfg(feature = "dtype-datetime")] - // Date to Datetime - DataType::Datetime(tu, None) => { - let ndt = arrow::temporal_conversions::date32_to_datetime(*v); - let func = match tu { - TimeUnit::Nanoseconds => datetime_to_timestamp_ns, - TimeUnit::Microseconds => datetime_to_timestamp_us, - TimeUnit::Milliseconds => datetime_to_timestamp_ms, - }; - let value = func(ndt); - AnyValue::Datetime(value, *tu, &None) + *tu, + ), + #[cfg(feature = "dtype-duration")] + (AnyValue::Duration(v, tu), DataType::Duration(tu_r)) => AnyValue::Duration( + match (tu, tu_r) { + (_, _) if tu == tu_r => *v, + (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64, + (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64, + (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64, + (TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64, + _ => *v, }, - _ => return cast_numeric(self, dtype), - }, - AnyValue::String(s) if dtype == &DataType::Binary => AnyValue::Binary(s.as_bytes()), - _ => { - polars_bail!(ComputeError: "cannot cast any-value '{:?}' to '{:?}'", self.dtype(), dtype) - }, + *tu_r, + ), + + // to self + (av, dtype) if av.dtype() == *dtype => self.clone(), + + av => polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", av, dtype), }; Ok(new_av) } - pub fn cast(&self, dtype: &'a DataType) -> PolarsResult> { + pub fn cast(&self, dtype: &'a DataType) -> AnyValue<'a> { match self.strict_cast(dtype) { - Ok(s) => Ok(s), - Err(_) => Ok(AnyValue::Null), + Ok(av) => av, + Err(_) => AnyValue::Null, } } } @@ -606,6 +625,12 @@ impl From> for DataType { } } +impl<'a> From<&AnyValue<'a>> for DataType { + fn from(value: &AnyValue<'a>) -> Self { + value.dtype() + } +} + impl AnyValue<'_> { pub fn hash_impl(&self, state: &mut H, cheap: bool) { use AnyValue::*; @@ -799,13 +824,13 @@ impl<'a> AnyValue<'a> { #[cfg(feature = "dtype-struct")] StructOwned(payload) => { let av = StructOwned(payload); - // safety: owned is already static + // SAFETY: owned is already static unsafe { std::mem::transmute::, AnyValue<'static>>(av) } }, #[cfg(feature = "object")] ObjectOwned(payload) => { let av = ObjectOwned(payload); - // safety: owned is already static + // SAFETY: owned is already static unsafe { std::mem::transmute::, AnyValue<'static>>(av) } }, #[cfg(feature = "dtype-decimal")] @@ -833,16 +858,6 @@ impl<'a> AnyValue<'a> { _ => None, } } - - pub fn is_nested_null(&self) -> bool { - match self { - AnyValue::Null => true, - AnyValue::List(s) => s.dtype().is_nested_null(), - #[cfg(feature = "dtype-struct")] - AnyValue::Struct(_, _, _) => self._iter_struct_av().all(|av| av.is_nested_null()), - _ => false, - } - } } impl<'a> From> for Option { @@ -871,8 +886,8 @@ impl AnyValue<'_> { (Int16(l), Int16(r)) => *l == *r, (Int32(l), Int32(r)) => *l == *r, (Int64(l), Int64(r)) => *l == *r, - (Float32(l), Float32(r)) => *l == *r, - (Float64(l), Float64(r)) => *l == *r, + (Float32(l), Float32(r)) => l.to_total_ord() == r.to_total_ord(), + (Float64(l), Float64(r)) => l.to_total_ord() == r.to_total_ord(), (String(l), String(r)) => l == r, (String(l), StringOwned(r)) => l == r, (StringOwned(l), String(r)) => l == r, @@ -956,8 +971,8 @@ impl PartialOrd for AnyValue<'_> { (Int16(l), Int16(r)) => l.partial_cmp(r), (Int32(l), Int32(r)) => l.partial_cmp(r), (Int64(l), Int64(r)) => l.partial_cmp(r), - (Float32(l), Float32(r)) => l.partial_cmp(r), - (Float64(l), Float64(r)) => l.partial_cmp(r), + (Float32(l), Float32(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()), + (Float64(l), Float64(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()), (String(l), String(r)) => l.partial_cmp(*r), (Binary(l), Binary(r)) => l.partial_cmp(*r), _ => None, diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index ef29cbe8d85e..9fb5ea7f9ac9 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -1,6 +1,4 @@ use std::collections::BTreeMap; -use std::convert::Into; -use std::string::ToString; use super::*; #[cfg(feature = "object")] @@ -206,7 +204,7 @@ impl DataType { self.is_float() || self.is_integer() } - /// Check if this [`DataType`] is a basic numeric type (excludes Decimal). + /// Check if this [`DataType`] is a boolean pub fn is_bool(&self) -> bool { matches!(self, DataType::Boolean) } diff --git a/crates/polars-core/src/datatypes/field.rs b/crates/polars-core/src/datatypes/field.rs index 5c41347d3c97..8b0664b14168 100644 --- a/crates/polars-core/src/datatypes/field.rs +++ b/crates/polars-core/src/datatypes/field.rs @@ -147,6 +147,10 @@ impl DataType { ArrowDataType::Struct(fields) => { DataType::Struct(fields.iter().map(|fld| fld.into()).collect()) } + #[cfg(not(feature = "dtype-struct"))] + ArrowDataType::Struct(_) => { + panic!("activate the 'dtype-struct' feature to handle struct data types") + } ArrowDataType::Extension(name, _, _) if name == "POLARS_EXTENSION_TYPE" => { #[cfg(feature = "object")] { diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index 5fc4b4427d2b..8c14f78379f8 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -35,6 +35,7 @@ use bytemuck::Zeroable; pub use dtype::*; pub use field::*; use num_traits::{Bounded, FromPrimitive, Num, NumCast, One, Zero}; +use polars_compute::arithmetic::HasPrimitiveArithmeticKernel; use polars_utils::abs_diff::AbsDiff; use polars_utils::float::IsFloat; use polars_utils::min_max::MinMax; @@ -47,7 +48,6 @@ use serde::{Deserialize, Serialize}; use serde::{Deserializer, Serializer}; pub use time_unit::*; -use crate::chunked_array::arithmetic::ArrayArithmetics; pub use crate::chunked_array::logical::*; #[cfg(feature = "object")] use crate::chunked_array::object::ObjectArray; @@ -263,44 +263,56 @@ pub trait NumericNative: + Bounded + FromPrimitive + IsFloat - + ArrayArithmetics + + HasPrimitiveArithmeticKernel::Native> + MinMax + IsNull { type PolarsType: PolarsNumericType; + type TrueDivPolarsType: PolarsNumericType; } impl NumericNative for i8 { type PolarsType = Int8Type; + type TrueDivPolarsType = Float64Type; } impl NumericNative for i16 { type PolarsType = Int16Type; + type TrueDivPolarsType = Float64Type; } impl NumericNative for i32 { type PolarsType = Int32Type; + type TrueDivPolarsType = Float64Type; } impl NumericNative for i64 { type PolarsType = Int64Type; + type TrueDivPolarsType = Float64Type; } impl NumericNative for u8 { type PolarsType = UInt8Type; + type TrueDivPolarsType = Float64Type; } impl NumericNative for u16 { type PolarsType = UInt16Type; + type TrueDivPolarsType = Float64Type; } impl NumericNative for u32 { type PolarsType = UInt32Type; + type TrueDivPolarsType = Float64Type; } impl NumericNative for u64 { type PolarsType = UInt64Type; + type TrueDivPolarsType = Float64Type; } #[cfg(feature = "dtype-decimal")] impl NumericNative for i128 { type PolarsType = Int128Type; + type TrueDivPolarsType = Float64Type; } impl NumericNative for f32 { type PolarsType = Float32Type; + type TrueDivPolarsType = Float32Type; } impl NumericNative for f64 { type PolarsType = Float64Type; + type TrueDivPolarsType = Float64Type; } diff --git a/crates/polars-core/src/fmt.rs b/crates/polars-core/src/fmt.rs index dfaf4f1375cb..1a1ab80f342a 100644 --- a/crates/polars-core/src/fmt.rs +++ b/crates/polars-core/src/fmt.rs @@ -25,7 +25,10 @@ use num_traits::{Num, NumCast}; use crate::config::*; use crate::prelude::*; -const LIMIT: usize = 25; + +// Note: see https://github.com/pola-rs/polars/pull/13699 for the rationale +// behind choosing 10 as the default value for default number of rows displayed +const LIMIT: usize = 10; #[derive(Copy, Clone)] #[repr(u8)] @@ -130,19 +133,18 @@ macro_rules! format_array { }; Ok(()) }; - if (limit == 0 && $a.len() > 0) || ($a.len() > limit + 1) { - if limit > 0 { - for i in 0..std::cmp::max((limit / 2), 1) { - let v = $a.get_any_value(i).unwrap(); - write_fn(v, $f)?; - } + if $a.len() > limit { + let half = limit / 2; + let rest = limit % 2; + + for i in 0..(half + rest) { + let v = $a.get_any_value(i).unwrap(); + write_fn(v, $f)?; } write!($f, "\t…\n")?; - if limit > 1 { - for i in ($a.len() - (limit + 1) / 2)..$a.len() { - let v = $a.get_any_value(i).unwrap(); - write_fn(v, $f)?; - } + for i in ($a.len() - half)..$a.len() { + let v = $a.get_any_value(i).unwrap(); + write_fn(v, $f)?; } } else { for i in 0..$a.len() { @@ -163,7 +165,7 @@ fn format_object_array( array_type: &str, ) -> fmt::Result { match object.dtype() { - DataType::Object(inner_type, None) => { + DataType::Object(inner_type, _) => { let limit = std::cmp::min(LIMIT, object.len()); write!( f, @@ -333,7 +335,7 @@ impl Debug for Series { format_array!(f, self.list().unwrap(), &dt, self.name(), "Series") }, #[cfg(feature = "object")] - DataType::Object(_, None) => format_object_array(f, self, self.name(), "Series"), + DataType::Object(_, _) => format_object_array(f, self, self.name(), "Series"), #[cfg(feature = "dtype-categorical")] DataType::Categorical(_, _) => { format_array!(f, self.categorical().unwrap(), "cat", self.name(), "Series") @@ -524,9 +526,7 @@ impl Display for DataFrame { .as_deref() .unwrap_or("") .parse() - // Note: see "https://github.com/pola-rs/polars/pull/13699" for - // the rationale behind choosing 10 as the default value ;) - .map_or(10, |n: i64| if n < 0 { height } else { n as usize }); + .map_or(LIMIT, |n: i64| if n < 0 { height } else { n as usize }); let (n_first, n_last) = if self.width() > max_n_cols { ((max_n_cols + 1) / 2, max_n_cols / 2) @@ -588,11 +588,15 @@ impl Display for DataFrame { let mut max_elem_lengths: Vec = vec![0; n_tbl_cols]; if max_n_rows > 0 { - if height > max_n_rows + 1 { - // Truncate the table if we have more rows than the configured maximum - // number of rows plus the single row which would contain "…". + if height > max_n_rows { + // Truncate the table if we have more rows than the + // configured maximum number of rows let mut rows = Vec::with_capacity(std::cmp::max(max_n_rows, 2)); - for i in 0..std::cmp::max(max_n_rows / 2, 1) { + + let half = max_n_rows / 2; + let rest = max_n_rows % 2; + + for i in 0..(half + rest) { let row = self .columns .iter() @@ -606,23 +610,16 @@ impl Display for DataFrame { } let dots = rows[0].iter().map(|_| "…".to_string()).collect(); rows.push(dots); - if max_n_rows > 1 { - for i in (height - (max_n_rows + 1) / 2)..height { - let row = self - .columns - .iter() - .map(|s| s.str_value(i).unwrap()) - .collect(); + for i in (height - half)..height { + let row = self + .columns + .iter() + .map(|s| s.str_value(i).unwrap()) + .collect(); - let row_strings = prepare_row( - row, - n_first, - n_last, - str_truncate, - &mut max_elem_lengths, - ); - rows.push(row_strings); - } + let row_strings = + prepare_row(row, n_first, n_last, str_truncate, &mut max_elem_lengths); + rows.push(row_strings); } table.add_rows(rows); } else { @@ -991,10 +988,10 @@ impl Display for AnyValue<'_> { let width = 0; match self { AnyValue::Null => write!(f, "null"), - AnyValue::UInt8(v) => write!(f, "{v}"), - AnyValue::UInt16(v) => write!(f, "{v}"), - AnyValue::UInt32(v) => write!(f, "{v}"), - AnyValue::UInt64(v) => write!(f, "{v}"), + AnyValue::UInt8(v) => fmt_integer(f, width, *v), + AnyValue::UInt16(v) => fmt_integer(f, width, *v), + AnyValue::UInt32(v) => fmt_integer(f, width, *v), + AnyValue::UInt64(v) => fmt_integer(f, width, *v), AnyValue::Int8(v) => fmt_integer(f, width, *v), AnyValue::Int16(v) => fmt_integer(f, width, *v), AnyValue::Int32(v) => fmt_integer(f, width, *v), diff --git a/crates/polars-core/src/frame/arithmetic.rs b/crates/polars-core/src/frame/arithmetic.rs index be60fb04346f..0082ecc4534a 100644 --- a/crates/polars-core/src/frame/arithmetic.rs +++ b/crates/polars-core/src/frame/arithmetic.rs @@ -21,7 +21,7 @@ macro_rules! impl_arithmetic { let cols = POOL.install(|| {$self.columns.par_iter().map(|s| { Ok(&s.cast(&st)? $operand &rhs) }).collect::>()})?; - Ok(DataFrame::new_no_checks(cols)) + Ok(unsafe { DataFrame::new_no_checks(cols) }) }} } @@ -113,7 +113,7 @@ impl DataFrame { ) -> PolarsResult { let max_len = std::cmp::max(self.height(), other.height()); let max_width = std::cmp::max(self.width(), other.width()); - let mut cols = self + let cols = self .get_columns() .par_iter() .zip(other.get_columns().par_iter()) @@ -133,8 +133,8 @@ impl DataFrame { }; f(&l, &r) - }) - .collect::>>()?; + }); + let mut cols = POOL.install(|| cols.collect::>>())?; let col_len = cols.len(); if col_len < max_width { diff --git a/crates/polars-core/src/frame/explode.rs b/crates/polars-core/src/frame/explode.rs index b75e1efbe33a..51fe294aa3fa 100644 --- a/crates/polars-core/src/frame/explode.rs +++ b/crates/polars-core/src/frame/explode.rs @@ -96,7 +96,7 @@ impl DataFrame { let mut row_idx = IdxCa::from_vec("", row_idx); row_idx.set_sorted_flag(IsSorted::Ascending); - // Safety + // SAFETY: // We just created indices that are in bounds. let mut df = unsafe { df.take_unchecked(&row_idx) }; process_column(self, &mut df, exploded.clone())?; @@ -259,13 +259,25 @@ impl DataFrame { let id_vars = args.id_vars; let mut value_vars = args.value_vars; - let value_name = args.value_name.as_deref().unwrap_or("value"); let variable_name = args.variable_name.as_deref().unwrap_or("variable"); + let value_name = args.value_name.as_deref().unwrap_or("value"); let len = self.height(); // if value vars is empty we take all columns that are not in id_vars. if value_vars.is_empty() { + // return empty frame if there are no columns available to use as value vars + if id_vars.len() == self.width() { + let variable_col = Series::new_empty(variable_name, &DataType::String); + let value_col = Series::new_empty(variable_name, &DataType::Null); + + let mut out = self.select(id_vars).unwrap().clear().columns; + out.push(variable_col); + out.push(value_col); + + return Ok(unsafe { DataFrame::new_no_checks(out) }); + } + let id_vars_set = PlHashSet::from_iter(id_vars.iter().map(|s| s.as_str())); value_vars = self .get_columns() @@ -318,14 +330,14 @@ impl DataFrame { values.extend_from_slice(value_col.chunks()) } let values_arr = concatenate_owned_unchecked(&values)?; - // Safety + // SAFETY: // The give dtype is correct let values = unsafe { Series::from_chunks_and_dtype_unchecked(value_name, vec![values_arr], &st) }; let variable_col = variable_col.as_box(); - // Safety - // The give dtype is correct + // SAFETY: + // The given dtype is correct let variables = unsafe { Series::from_chunks_and_dtype_unchecked( variable_name, @@ -342,7 +354,6 @@ impl DataFrame { #[cfg(test)] mod test { - use crate::frame::explode::MeltArgs; use crate::prelude::*; #[test] diff --git a/crates/polars-core/src/frame/from.rs b/crates/polars-core/src/frame/from.rs index 79fe26083a46..72172ec7e736 100644 --- a/crates/polars-core/src/frame/from.rs +++ b/crates/polars-core/src/frame/from.rs @@ -1,5 +1,3 @@ -use arrow::array::StructArray; - use crate::prelude::*; impl TryFrom for DataFrame { @@ -15,7 +13,7 @@ impl TryFrom for DataFrame { .iter() .zip(arrs) .map(|(fld, arr)| { - // Safety + // SAFETY: // reported data type is correct unsafe { Series::_try_from_arrow_unchecked_with_md( @@ -37,7 +35,7 @@ impl From<&Schema> for DataFrame { .iter() .map(|(name, dtype)| Series::new_empty(name, dtype)) .collect(); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } } @@ -48,6 +46,6 @@ impl From<&ArrowSchema> for DataFrame { .iter() .map(|fld| Series::new_empty(fld.name.as_str(), &(fld.data_type().into()))) .collect(); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } } diff --git a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs index ade50c3d5899..99fc3393fc2b 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs @@ -40,14 +40,14 @@ where } length_so_far += idx_len as i64; - // Safety: + // SAFETY: // group tuples are in bounds { list_values.extend(idx.iter().map(|idx| { debug_assert!((*idx as usize) < values.len()); *values.get_unchecked(*idx as usize) })); - // Safety: + // SAFETY: // we know that offsets has allocated enough slots offsets.push_unchecked(length_so_far); } @@ -77,7 +77,7 @@ where validity, ); let data_type = ListArray::::default_datatype(T::get_dtype().to_arrow(true)); - // Safety: + // SAFETY: // offsets are monotonically increasing let arr = ListArray::::new( data_type, @@ -110,7 +110,7 @@ where length_so_far += len as i64; list_values.extend_from_slice(&values[first as usize..(first + len) as usize]); { - // Safety: + // SAFETY: // we know that offsets has allocated enough slots offsets.push_unchecked(length_so_far); } diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index d930f88f9f18..82f661dc0752 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -114,14 +114,14 @@ impl Series { Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_median(groups), dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_median, groups), #[cfg(feature = "dtype-datetime")] - dt @ Datetime(_, _) => self + dt @ (Datetime(_, _) | Duration(_)) => self .to_physical_repr() .agg_median(groups) .cast(&Int64) .unwrap() .cast(dt) .unwrap(), - dt @ (Date | Duration(_) | Time) => { + dt @ (Date | Time) => { let ca = self.to_physical_repr(); let physical_type = ca.dtype(); let s = apply_method_physical_integer!(ca, agg_median, groups); @@ -172,14 +172,14 @@ impl Series { Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_mean(groups), dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_mean, groups), #[cfg(feature = "dtype-datetime")] - dt @ Datetime(_, _) => self + dt @ (Datetime(_, _) | Duration(_)) => self .to_physical_repr() .agg_mean(groups) .cast(&Int64) .unwrap() .cast(dt) .unwrap(), - dt @ (Date | Duration(_) | Time) => { + dt @ (Date | Time) => { let ca = self.to_physical_repr(); let physical_type = ca.dtype(); let s = apply_method_physical_integer!(ca, agg_mean, groups); diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index 68c350dadc91..704408ebed64 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -12,7 +12,6 @@ use arrow::legacy::kernels::rolling::no_nulls::{ MaxWindow, MeanWindow, MinWindow, QuantileWindow, RollingAggWindowNoNulls, SumWindow, VarWindow, }; use arrow::legacy::kernels::rolling::nulls::RollingAggWindowNulls; -use arrow::legacy::kernels::rolling::{RollingQuantileParams, RollingVarParams}; use arrow::legacy::kernels::take_agg::*; use arrow::legacy::prelude::QuantileInterpolOptions; use arrow::legacy::trusted_len::TrustedLenPush; @@ -78,7 +77,7 @@ where // these represent the number of groups in the group_by operation let output_len = offsets.size_hint().0; // start with a dummy index, will be overwritten on first iteration. - // Safety: + // SAFETY: // we are in bounds let mut agg_window = unsafe { Agg::new(values, validity, 0, 0, params) }; @@ -90,7 +89,7 @@ where .map(|(idx, (start, len))| { let end = start + len; - // safety: + // SAFETY: // we are in bounds let agg = if start == end { @@ -102,7 +101,7 @@ where match agg { Some(val) => val, None => { - // safety: we are in bounds + // SAFETY: we are in bounds unsafe { validity.set_unchecked(idx, false) }; T::default() }, diff --git a/crates/polars-core/src/frame/group_by/aggregations/string.rs b/crates/polars-core/src/frame/group_by/aggregations/string.rs index 889217addd90..4891852e1dab 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/string.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/string.rs @@ -69,7 +69,7 @@ impl BinaryChunked { let arr_group = _slice_from_offsets(self, first, len); let borrowed = arr_group.min_binary(); - // Safety: + // SAFETY: // The borrowed has `arr_group`s lifetime, but it actually points to data // hold by self. Here we tell the compiler that. unsafe { std::mem::transmute::, Option<&'a [u8]>>(borrowed) } @@ -131,7 +131,7 @@ impl BinaryChunked { let arr_group = _slice_from_offsets(self, first, len); let borrowed = arr_group.max_binary(); - // Safety: + // SAFETY: // The borrowed has `arr_group`s lifetime, but it actually points to data // hold by self. Here we tell the compiler that. unsafe { std::mem::transmute::, Option<&'a [u8]>>(borrowed) } diff --git a/crates/polars-core/src/frame/group_by/hashing.rs b/crates/polars-core/src/frame/group_by/hashing.rs index 9ad0ab215a48..820623d2c4ba 100644 --- a/crates/polars-core/src/frame/group_by/hashing.rs +++ b/crates/polars-core/src/frame/group_by/hashing.rs @@ -4,20 +4,16 @@ use hashbrown::hash_map::{Entry, RawEntryMut}; use hashbrown::HashMap; use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::idx_vec::IdxVec; -use polars_utils::idxvec; use polars_utils::iter::EnumerateIdxTrait; use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; +use polars_utils::unitvec; use rayon::prelude::*; -use super::GroupsProxy; -use crate::datatypes::PlHashMap; -use crate::frame::group_by::{GroupsIdx, IdxItem}; -use crate::hashing::{ - _df_rows_to_hashes_threaded_vertical, series_to_hashes, IdBuildHasher, IdxHash, *, -}; +use crate::hashing::*; use crate::prelude::compare_inner::TotalEqInner; use crate::prelude::*; -use crate::utils::{flatten, split_df, CustomIterTools}; +use crate::utils::{flatten, split_df}; use crate::POOL; fn get_init_size() -> usize { @@ -144,19 +140,22 @@ fn finish_group_order_vecs( pub(crate) fn group_by(a: impl Iterator, sorted: bool) -> GroupsProxy where - T: Hash + Eq, + T: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let init_size = get_init_size(); - let mut hash_tbl: PlHashMap = PlHashMap::with_capacity(init_size); + let mut hash_tbl: PlHashMap = + PlHashMap::with_capacity(init_size); let mut cnt = 0; a.for_each(|k| { + let k = k.to_total_ord(); let idx = cnt; cnt += 1; let entry = hash_tbl.entry(k); match entry { Entry::Vacant(entry) => { - let tuples = idxvec![idx]; + let tuples = unitvec![idx]; entry.insert((idx, tuples)); }, Entry::Occupied(mut entry) => { @@ -188,7 +187,8 @@ pub(crate) fn group_by_threaded_slice( sorted: bool, ) -> GroupsProxy where - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Hash + Eq + Sync + Copy + DirtyHash, IntoSlice: AsRef<[T]> + Send + Sync, { let init_size = get_init_size(); @@ -200,7 +200,7 @@ where (0..n_partitions) .into_par_iter() .map(|thread_no| { - let mut hash_tbl: PlHashMap = + let mut hash_tbl: PlHashMap = PlHashMap::with_capacity(init_size); let mut offset = 0; @@ -211,17 +211,18 @@ where let mut cnt = 0; keys.iter().for_each(|k| { + let k = k.to_total_ord(); let idx = cnt + offset; cnt += 1; if thread_no == hash_to_partition(k.dirty_hash(), n_partitions) { let hash = hasher.hash_one(k); - let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, k); + let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, &k); match entry { RawEntryMut::Vacant(entry) => { - let tuples = idxvec![idx]; - entry.insert_with_hasher(hash, *k, (idx, tuples), |k| { + let tuples = unitvec![idx]; + entry.insert_with_hasher(hash, k, (idx, tuples), |k| { hasher.hash_one(k) }); }, @@ -252,7 +253,8 @@ pub(crate) fn group_by_threaded_iter( where I: IntoIterator + Send + Sync + Clone, I::IntoIter: ExactSizeIterator, - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Hash + Eq + Sync + Copy + DirtyHash, { let init_size = get_init_size(); @@ -263,7 +265,7 @@ where (0..n_partitions) .into_par_iter() .map(|thread_no| { - let mut hash_tbl: PlHashMap = + let mut hash_tbl: PlHashMap = PlHashMap::with_capacity(init_size); let mut offset = 0; @@ -274,6 +276,7 @@ where let mut cnt = 0; keys.for_each(|k| { + let k = k.to_total_ord(); let idx = cnt + offset; cnt += 1; @@ -283,7 +286,7 @@ where match entry { RawEntryMut::Vacant(entry) => { - let tuples = idxvec![idx]; + let tuples = unitvec![idx]; entry.insert_with_hasher(hash, k, (idx, tuples), |k| { hasher.hash_one(k) }); @@ -362,7 +365,7 @@ pub(crate) fn populate_multiple_key_hashmap2<'a, V, H, F, G>( // cache misses original_h == idx_hash.hash && { let key_idx = idx_hash.idx; - // Safety: + // SAFETY: // indices in a group_by operation are always in bounds. unsafe { compare_keys(keys_cmp, key_idx as usize, idx as usize) } } @@ -438,7 +441,7 @@ pub(crate) fn group_by_threaded_multiple_keys_flat( let all_vals = &mut *(all_buf_ptr as *mut Vec); let offset_idx = first_vals.len() as IdxSize; - let tuples = idxvec![row_idx]; + let tuples = unitvec![row_idx]; all_vals.push(tuples); first_vals.push(row_idx); offset_idx @@ -501,7 +504,7 @@ pub(crate) fn group_by_multiple_keys(keys: DataFrame, sorted: bool) -> PolarsRes let all_vals = &mut *(all_buf_ptr as *mut Vec); let offset_idx = first_vals.len() as IdxSize; - let tuples = idxvec![row_idx]; + let tuples = unitvec![row_idx]; all_vals.push(tuples); first_vals.push(row_idx); offset_idx diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index 067844a1dafb..6dfc547da6b5 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -2,6 +2,7 @@ use arrow::legacy::kernels::list_bytes_iter::numeric_list_bytes_iter; use arrow::legacy::kernels::sort_partition::{create_clean_partitions, partition_to_groups}; use arrow::legacy::prelude::*; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use super::*; use crate::config::verbose; @@ -25,9 +26,9 @@ fn group_multithreaded(ca: &ChunkedArray) -> bool { fn num_groups_proxy(ca: &ChunkedArray, multithreaded: bool, sorted: bool) -> GroupsProxy where - T: PolarsIntegerType, - T::Native: Hash + Eq + Send + DirtyHash, - Option: DirtyHash, + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash, { if multithreaded && group_multithreaded(ca) { let n_partitions = _set_partition_size(); @@ -49,7 +50,7 @@ where } else if !ca.has_validity() { group_by(ca.into_no_null_iter(), sorted) } else { - group_by(ca.into_iter(), sorted) + group_by(ca.iter(), sorted) } } @@ -93,35 +94,31 @@ where let n_parts = parts.len(); let first_ptr = &values[0] as *const T::Native as usize; - let groups = POOL - .install(|| { - parts.par_iter().enumerate().map(|(i, part)| { - // we go via usize as *const is not send - let first_ptr = first_ptr as *const T::Native; - - let part_first_ptr = &part[0] as *const T::Native; - let mut offset = - unsafe { part_first_ptr.offset_from(first_ptr) } as IdxSize; - - // nulls first: only add the nulls at the first partition - if nulls_first && i == 0 { - partition_to_groups(part, null_count as IdxSize, true, offset) - } - // nulls last: only compute at the last partition - else if !nulls_first && i == n_parts - 1 { - partition_to_groups(part, null_count as IdxSize, false, offset) - } - // other partitions - else { - if nulls_first { - offset += null_count as IdxSize; - }; + let groups = parts.par_iter().enumerate().map(|(i, part)| { + // we go via usize as *const is not send + let first_ptr = first_ptr as *const T::Native; - partition_to_groups(part, 0, false, offset) - } - }) - }) - .collect::>(); + let part_first_ptr = &part[0] as *const T::Native; + let mut offset = unsafe { part_first_ptr.offset_from(first_ptr) } as IdxSize; + + // nulls first: only add the nulls at the first partition + if nulls_first && i == 0 { + partition_to_groups(part, null_count as IdxSize, true, offset) + } + // nulls last: only compute at the last partition + else if !nulls_first && i == n_parts - 1 { + partition_to_groups(part, null_count as IdxSize, false, offset) + } + // other partitions + else { + if nulls_first { + offset += null_count as IdxSize; + }; + + partition_to_groups(part, 0, false, offset) + } + }); + let groups = POOL.install(|| groups.collect::>()); flatten_par(&groups) } else { partition_to_groups(values, null_count as IdxSize, nulls_first, 0) @@ -167,14 +164,28 @@ where }; num_groups_proxy(ca, multithreaded, sorted) }, - DataType::Int64 | DataType::Float64 => { + DataType::Int64 => { let ca = self.bit_repr_large(); num_groups_proxy(&ca, multithreaded, sorted) }, - DataType::Int32 | DataType::Float32 => { + DataType::Int32 => { let ca = self.bit_repr_small(); num_groups_proxy(&ca, multithreaded, sorted) }, + DataType::Float64 => { + // convince the compiler that we are this type. + let ca: &Float64Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, + DataType::Float32 => { + // convince the compiler that we are this type. + let ca: &Float32Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, #[cfg(all(feature = "performant", feature = "dtype-i8", feature = "dtype-u8"))] DataType::Int8 => { // convince the compiler that we are this type. @@ -271,7 +282,7 @@ impl IntoGroupsProxy for BinaryChunked { let ca = self.slice(offset as i64, len); let byte_hashes = fill_bytes_hashes(&ca, null_h, hb.clone()); - // Safety: + // SAFETY: // the underlying data is tied to self unsafe { std::mem::transmute::>, Vec>>( @@ -327,7 +338,7 @@ impl IntoGroupsProxy for BinaryOffsetChunked { let ca = self.slice(offset as i64, len); let byte_hashes = fill_bytes_offset_hashes(&ca, null_h, hb.clone()); - // Safety: + // SAFETY: // the underlying data is tied to self unsafe { std::mem::transmute::>, Vec>>( @@ -371,7 +382,7 @@ impl IntoGroupsProxy for ListChunked { None => null_h, }; - // Safety: + // SAFETY: // the underlying data is tied to self unsafe { std::mem::transmute::, BytesHash<'a>>(BytesHash::new( diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 39a73e583ce7..75df1e198c50 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -2,7 +2,6 @@ use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; use ahash::RandomState; -use arrow::legacy::prelude::QuantileInterpolOptions; use num_traits::NumCast; use polars_utils::hashing::{BytesHash, DirtyHash}; use rayon::prelude::*; @@ -28,28 +27,31 @@ use crate::prelude::sort::arg_sort_multiple::encode_rows_vertical; // This will remove the sorted flag on signed integers fn prepare_dataframe_unsorted(by: &[Series]) -> DataFrame { - DataFrame::new_no_checks( - by.iter() - .map(|s| match s.dtype() { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - s.cast(&DataType::UInt32).unwrap() - }, - _ => { - if s.dtype().to_physical().is_numeric() { - let s = s.to_physical_repr(); - if s.bit_repr_is_large() { - s.bit_repr_large().into_series() - } else { - s.bit_repr_small().into_series() - } + let columns = by + .iter() + .map(|s| match s.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + s.cast(&DataType::UInt32).unwrap() + }, + _ => { + if s.dtype().to_physical().is_numeric() { + let s = s.to_physical_repr(); + + if s.dtype().is_float() { + s.into_owned().into_series() + } else if s.bit_repr_is_large() { + s.bit_repr_large().into_series() } else { - s.clone() + s.bit_repr_small().into_series() } - }, - }) - .collect(), - ) + } else { + s.clone() + } + }, + }) + .collect(); + unsafe { DataFrame::new_no_checks(columns) } } impl DataFrame { @@ -793,7 +795,7 @@ impl<'df> GroupBy<'df> { new_cols.extend_from_slice(&self.selected_keys); let cols = self.df.select_series(agg)?; new_cols.extend(cols); - Ok(DataFrame::new_no_checks(new_cols)) + Ok(unsafe { DataFrame::new_no_checks(new_cols) }) } } else { Ok(self.df.clone()) @@ -811,7 +813,7 @@ impl<'df> GroupBy<'df> { .get_groups() .par_iter() .map(|g| { - // safety + // SAFETY: // groups are in bounds let sub_df = unsafe { take_df(&df, g) }; f(sub_df) @@ -833,7 +835,7 @@ impl<'df> GroupBy<'df> { .get_groups() .iter() .map(|g| { - // safety + // SAFETY: // groups are in bounds let sub_df = unsafe { take_df(&df, g) }; f(sub_df) diff --git a/crates/polars-core/src/frame/group_by/perfect.rs b/crates/polars-core/src/frame/group_by/perfect.rs index 795e38194051..bc56f80fdd7d 100644 --- a/crates/polars-core/src/frame/group_by/perfect.rs +++ b/crates/polars-core/src/frame/group_by/perfect.rs @@ -1,6 +1,5 @@ use std::fmt::Debug; -use arrow::array::Array; use arrow::legacy::bit_util::round_upto_multiple_of_64; use num_traits::{FromPrimitive, ToPrimitive}; use polars_utils::idx_vec::IdxVec; @@ -78,7 +77,7 @@ where let end = cache_line_offsets[thread_no + 1]; let end = T::Native::from_usize(end).unwrap(); - // safety: we don't alias + // SAFETY: we don't alias let groups = unsafe { std::slice::from_raw_parts_mut(groups_ptr.get(), len) }; let first = unsafe { std::slice::from_raw_parts_mut(first_ptr.get(), len) }; @@ -93,7 +92,7 @@ where unsafe { if buf.len() == 1 { - // safety: we just pushed + // SAFETY: we just pushed let first_value = buf.get_unchecked(0); *first.get_unchecked_release_mut(cat) = *first_value } @@ -113,7 +112,7 @@ where unsafe { if buf.len() == 1 { - // safety: we just pushed + // SAFETY: we just pushed let first_value = buf.get_unchecked(0); *first.get_unchecked_release_mut(cat) = *first_value @@ -198,7 +197,7 @@ impl CategoricalChunked { let mut out = match &**rev_map { RevMapping::Local(cached, _) => { - if self.can_fast_unique() { + if self._can_fast_unique() { if verbose() { eprintln!("grouping categoricals, run perfect hash function"); } diff --git a/crates/polars-core/src/frame/group_by/proxy.rs b/crates/polars-core/src/frame/group_by/proxy.rs index 7eae13edeff9..e1988c363712 100644 --- a/crates/polars-core/src/frame/group_by/proxy.rs +++ b/crates/polars-core/src/frame/group_by/proxy.rs @@ -1,7 +1,6 @@ use std::mem::ManuallyDrop; use std::ops::Deref; -use arrow::legacy::utils::CustomIterTools; use polars_utils::idx_vec::IdxVec; use polars_utils::sync::SyncPtr; use rayon::iter::plumbing::UnindexedConsumer; @@ -500,7 +499,7 @@ impl GroupsProxy { } pub fn slice(&self, offset: i64, len: usize) -> SlicedGroups { - // Safety: + // SAFETY: // we create new `Vec`s from the sliced groups. But we wrap them in ManuallyDrop // so that we never call drop on them. // These groups lifetimes are bounded to the `self`. This must remain valid diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 5c414c4c1ffa..487f712bdbbf 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1,6 +1,5 @@ //! DataFrame module. use std::borrow::Cow; -use std::iter::{FromIterator, Iterator}; use std::{mem, ops}; use ahash::AHashSet; @@ -28,8 +27,6 @@ pub use chunks::*; use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; -#[cfg(feature = "algorithm_group_by")] -use crate::frame::group_by::GroupsIndicator; #[cfg(feature = "row_hash")] use crate::hashing::_df_rows_to_hashes_threaded_vertical; #[cfg(feature = "zip_with")] @@ -156,12 +153,15 @@ impl DataFrame { } // Reduce monomorphization. - fn apply_columns(&self, func: &(dyn Fn(&Series) -> Series)) -> Vec { + pub fn _apply_columns(&self, func: &(dyn Fn(&Series) -> Series)) -> Vec { self.columns.iter().map(func).collect() } // Reduce monomorphization. - fn apply_columns_par(&self, func: &(dyn Fn(&Series) -> Series + Send + Sync)) -> Vec { + pub fn _apply_columns_par( + &self, + func: &(dyn Fn(&Series) -> Series + Send + Sync), + ) -> Vec { POOL.install(|| self.columns.par_iter().map(func).collect()) } @@ -198,7 +198,7 @@ impl DataFrame { /// Reserve additional slots into the chunks of the series. pub(crate) fn reserve_chunks(&mut self, additional: usize) { for s in &mut self.columns { - // Safety + // SAFETY: // do not modify the data, simply resize. unsafe { s.chunks_mut().reserve(additional) } } @@ -228,7 +228,7 @@ impl DataFrame { }; let series_cols = if S::is_series() { - // Safety: + // SAFETY: // we are guarded by the type system here. #[allow(clippy::transmute_undefined_repr)] let series_cols = unsafe { std::mem::transmute::, Vec>(columns) }; @@ -252,11 +252,9 @@ impl DataFrame { None => first_len = Some(s.len()), } - if names.contains(name) { + if !names.insert(name) { polars_bail!(duplicate = name); } - - names.insert(name); } // we drop early as the brchk thinks the &str borrows are used when calling the drop // of both `series_cols` and `names` @@ -312,7 +310,8 @@ impl DataFrame { /// static EMPTY: DataFrame = DataFrame::empty(); /// ``` pub const fn empty() -> Self { - DataFrame::new_no_checks(Vec::new()) + // SAFETY: An empty dataframe cannot have length mismatches or duplicate names + unsafe { DataFrame::new_no_checks(Vec::new()) } } /// Removes the last `Series` from the `DataFrame` and returns it, or [`None`] if it is empty. @@ -399,22 +398,46 @@ impl DataFrame { /// Create a new `DataFrame` but does not check the length or duplicate occurrence of the `Series`. /// - /// It is advised to use [Series::new](Series::new) in favor of this method. + /// It is advised to use [DataFrame::new] in favor of this method. + /// + /// # Safety /// - /// # Panic /// It is the callers responsibility to uphold the contract of all `Series` - /// having an equal length, if not this may panic down the line. - pub const fn new_no_checks(columns: Vec) -> DataFrame { + /// having an equal length and a unique name, if not this may panic down the line. + pub const unsafe fn new_no_checks(columns: Vec) -> DataFrame { DataFrame { columns } } + /// Create a new `DataFrame` but does not check the length of the `Series`, + /// only check for duplicates. + /// + /// It is advised to use [DataFrame::new] in favor of this method. + /// + /// # Safety + /// + /// It is the callers responsibility to uphold the contract of all `Series` + /// having an equal length, if not this may panic down the line. + pub unsafe fn new_no_length_checks(columns: Vec) -> PolarsResult { + let mut names = PlHashSet::with_capacity(columns.len()); + for column in &columns { + let name = column.name(); + if !names.insert(name) { + polars_bail!(duplicate = name); + } + } + // we drop early as the brchk thinks the &str borrows are used when calling the drop + // of both `columns` and `names` + drop(names); + Ok(DataFrame { columns }) + } + /// Aggregate all chunks to contiguous memory. #[must_use] pub fn agg_chunks(&self) -> Self { // Don't parallelize this. Memory overhead let f = |s: &Series| s.rechunk(); let cols = self.columns.iter().map(f).collect(); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } /// Shrink the capacity of this DataFrame to fit its length. @@ -438,7 +461,7 @@ impl DataFrame { /// This may lead to more peak memory consumption. pub fn as_single_chunk_par(&mut self) -> &mut Self { if self.columns.iter().any(|s| s.n_chunks() > 1) { - self.columns = self.apply_columns_par(&|s| s.rechunk()); + self.columns = self._apply_columns_par(&|s| s.rechunk()); } self } @@ -522,6 +545,7 @@ impl DataFrame { #[inline] /// Get mutable access to the underlying columns. + /// /// # Safety /// The caller must ensure the length of all [`Series`] remains equal. pub unsafe fn get_columns_mut(&mut self) -> &mut Vec { @@ -1065,7 +1089,7 @@ impl DataFrame { } }); - Ok(DataFrame::new_no_checks(new_cols)) + Ok(unsafe { DataFrame::new_no_checks(new_cols) }) } /// Drop columns that are in `names`. @@ -1083,7 +1107,7 @@ impl DataFrame { } }); - DataFrame::new_no_checks(new_cols) + unsafe { DataFrame::new_no_checks(new_cols) } } /// Insert a new column at a given index without checking for duplicates. @@ -1239,7 +1263,7 @@ impl DataFrame { }, None => return None, } - // safety: we just checked bounds + // SAFETY: we just checked bounds unsafe { Some(self.columns.iter().map(|s| s.get_unchecked(idx)).collect()) } } @@ -1431,7 +1455,7 @@ impl DataFrame { pub fn _select_impl_unchecked(&self, cols: &[SmartString]) -> PolarsResult { let selected = self.select_series_impl(cols)?; - Ok(DataFrame::new_no_checks(selected)) + Ok(unsafe { DataFrame::new_no_checks(selected) }) } /// Select with a known schema. @@ -1474,7 +1498,7 @@ impl DataFrame { self.select_check_duplicates(cols)?; } let selected = self.select_series_impl_with_schema(cols, schema)?; - Ok(DataFrame::new_no_checks(selected)) + Ok(unsafe { DataFrame::new_no_checks(selected) }) } /// A non generic implementation to reduce compiler bloat. @@ -1506,7 +1530,7 @@ impl DataFrame { fn select_physical_impl(&self, cols: &[SmartString]) -> PolarsResult { self.select_check_duplicates(cols)?; let selected = self.select_series_physical_impl(cols)?; - Ok(DataFrame::new_no_checks(selected)) + Ok(unsafe { DataFrame::new_no_checks(selected) }) } fn select_check_duplicates(&self, cols: &[SmartString]) -> PolarsResult<()> { @@ -1622,7 +1646,7 @@ impl DataFrame { .iter() .map(|s| s.filter(mask)) .collect::>()?; - Ok(DataFrame::new_no_checks(cols)) + Ok(unsafe { DataFrame::new_no_checks(cols) }) }) .collect() }); @@ -1651,13 +1675,13 @@ impl DataFrame { return self.clone().filter_vertical(mask); } let new_col = self.try_apply_columns_par(&|s| s.filter(mask))?; - Ok(DataFrame::new_no_checks(new_col)) + Ok(unsafe { DataFrame::new_no_checks(new_col) }) } /// Same as `filter` but does not parallelize. pub fn _filter_seq(&self, mask: &BooleanChunked) -> PolarsResult { let new_col = self.try_apply_columns(&|s| s.filter(mask))?; - Ok(DataFrame::new_no_checks(new_col)) + Ok(unsafe { DataFrame::new_no_checks(new_col) }) } /// Take [`DataFrame`] rows by index values. @@ -1674,7 +1698,7 @@ impl DataFrame { pub fn take(&self, indices: &IdxCa) -> PolarsResult { let new_col = POOL.install(|| self.try_apply_columns_par(&|s| s.take(indices)))?; - Ok(DataFrame::new_no_checks(new_col)) + Ok(unsafe { DataFrame::new_no_checks(new_col) }) } /// # Safety @@ -1685,11 +1709,11 @@ impl DataFrame { unsafe fn take_unchecked_impl(&self, idx: &IdxCa, allow_threads: bool) -> Self { let cols = if allow_threads { - POOL.install(|| self.apply_columns_par(&|s| s.take_unchecked(idx))) + POOL.install(|| self._apply_columns_par(&|s| s.take_unchecked(idx))) } else { self.columns.iter().map(|s| s.take_unchecked(idx)).collect() }; - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } pub(crate) unsafe fn take_slice_unchecked(&self, idx: &[IdxSize]) -> Self { @@ -1698,14 +1722,14 @@ impl DataFrame { unsafe fn take_slice_unchecked_impl(&self, idx: &[IdxSize], allow_threads: bool) -> Self { let cols = if allow_threads { - POOL.install(|| self.apply_columns_par(&|s| s.take_slice_unchecked(idx))) + POOL.install(|| self._apply_columns_par(&|s| s.take_slice_unchecked(idx))) } else { self.columns .iter() .map(|s| s.take_slice_unchecked(idx)) .collect() }; - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } /// Rename a column in the [`DataFrame`]. @@ -1845,7 +1869,7 @@ impl DataFrame { take = take.slice(offset, len); } - // Safety: + // SAFETY: // the created indices are in bounds let mut df = unsafe { df.take_unchecked_impl(&take, parallel) }; set_sorted(&mut df); @@ -2234,12 +2258,12 @@ impl DataFrame { .iter() .map(|s| s.slice(offset, length)) .collect::>(); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } pub fn clear(&self) -> Self { let col = self.columns.iter().map(|s| s.clear()).collect::>(); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } #[must_use] @@ -2247,7 +2271,8 @@ impl DataFrame { if offset == 0 && length == self.height() { return self.clone(); } - DataFrame::new_no_checks(self.apply_columns_par(&|s| s.slice(offset, length))) + let columns = self._apply_columns_par(&|s| s.slice(offset, length)); + unsafe { DataFrame::new_no_checks(columns) } } #[must_use] @@ -2255,11 +2280,12 @@ impl DataFrame { if offset == 0 && length == self.height() { return self.clone(); } - DataFrame::new_no_checks(self.apply_columns(&|s| { + let columns = self._apply_columns(&|s| { let mut out = s.slice(offset, length); out.shrink_to_fit(); out - })) + }); + unsafe { DataFrame::new_no_checks(columns) } } /// Get the head of the [`DataFrame`]. @@ -2302,7 +2328,7 @@ impl DataFrame { .iter() .map(|s| s.head(length)) .collect::>(); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } /// Get the tail of the [`DataFrame`]. @@ -2342,7 +2368,7 @@ impl DataFrame { .iter() .map(|s| s.tail(length)) .collect::>(); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } /// Iterator over the rows in this [`DataFrame`] as Arrow RecordBatches. @@ -2382,7 +2408,7 @@ impl DataFrame { #[must_use] pub fn reverse(&self) -> Self { let col = self.columns.iter().map(|s| s.reverse()).collect::>(); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } /// Shift the values by a given period and fill the parts that will be empty due to this operation @@ -2391,9 +2417,9 @@ impl DataFrame { /// See the method on [Series](crate::series::SeriesTrait::shift) for more info on the `shift` operation. #[must_use] pub fn shift(&self, periods: i64) -> Self { - let col = self.apply_columns_par(&|s| s.shift(periods)); + let col = self._apply_columns_par(&|s| s.shift(periods)); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } /// Replace None values with one of the following strategies: @@ -2407,7 +2433,7 @@ impl DataFrame { pub fn fill_null(&self, strategy: FillNullStrategy) -> PolarsResult { let col = self.try_apply_columns_par(&|s| s.fill_null(strategy))?; - Ok(DataFrame::new_no_checks(col)) + Ok(unsafe { DataFrame::new_no_checks(col) }) } /// Aggregate the column horizontally to their min values. @@ -2462,24 +2488,25 @@ impl DataFrame { } } - /// Aggregate the column horizontally to their sum values. + /// Sum all values horizontally across columns. pub fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { let apply_null_strategy = - |s: &Series, null_strategy: NullStrategy| -> PolarsResult { + |s: Series, null_strategy: NullStrategy| -> PolarsResult { if let NullStrategy::Ignore = null_strategy { // if has nulls - if s.has_validity() { + if s.null_count() > 0 { return s.fill_null(FillNullStrategy::Zero); } } - Ok(s.clone()) + Ok(s) }; let sum_fn = - |acc: &Series, s: &Series, null_strategy: NullStrategy| -> PolarsResult { + |acc: Series, s: Series, null_strategy: NullStrategy| -> PolarsResult { let acc: Series = apply_null_strategy(acc, null_strategy)?; let s = apply_null_strategy(s, null_strategy)?; - Ok(&acc + &s) + // This will do owned arithmetic and can be mutable + Ok(acc + s) }; let non_null_cols = self @@ -2497,26 +2524,34 @@ impl DataFrame { Ok(Some(self.columns[0].clone())) } }, - 1 => Ok(Some(apply_null_strategy(non_null_cols[0], null_strategy)?)), - 2 => sum_fn(non_null_cols[0], non_null_cols[1], null_strategy).map(Some), + 1 => Ok(Some(apply_null_strategy( + non_null_cols[0].clone(), + null_strategy, + )?)), + 2 => sum_fn( + non_null_cols[0].clone(), + non_null_cols[1].clone(), + null_strategy, + ) + .map(Some), _ => { // the try_reduce_with is a bit slower in parallelism, // but I don't think it matters here as we parallelize over columns, not over elements - POOL.install(|| { + let out = POOL.install(|| { non_null_cols .into_par_iter() - .map(|s| Ok(Cow::Borrowed(s))) - .try_reduce_with(|l, r| sum_fn(&l, &r, null_strategy).map(Cow::Owned)) - // we can unwrap the option, because we are certain there is a column - // we started this operation on 3 columns + .cloned() + .map(Ok) + .try_reduce_with(|l, r| sum_fn(l, r, null_strategy)) + // We can unwrap because we started with at least 3 columns, so we always get a Some .unwrap() - .map(|cow| Some(cow.into_owned())) - }) + }); + out.map(Some) }, } } - /// Aggregate the column horizontally to their mean values. + /// Compute the mean of all values horizontally across columns. pub fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { match self.columns.len() { 0 => Ok(None), @@ -2531,7 +2566,7 @@ impl DataFrame { }) .cloned() .collect(); - let numeric_df = DataFrame::new_no_checks(columns); + let numeric_df = unsafe { DataFrame::new_no_checks(columns) }; let sum = || numeric_df.sum_horizontal(null_strategy); @@ -2665,7 +2700,7 @@ impl DataFrame { let groups = gb.get_groups(); let (offset, len) = slice.unwrap_or((0, groups.len())); let groups = groups.slice(offset, len); - df.apply_columns_par(&|s| unsafe { s.agg_first(&groups) }) + df._apply_columns_par(&|s| unsafe { s.agg_first(&groups) }) }, (UniqueKeepStrategy::Last, true) => { // maintain order by last values, so the sorted groups are not correct as they @@ -2694,14 +2729,14 @@ impl DataFrame { let groups = gb.get_groups(); let (offset, len) = slice.unwrap_or((0, groups.len())); let groups = groups.slice(offset, len); - df.apply_columns_par(&|s| unsafe { s.agg_first(&groups) }) + df._apply_columns_par(&|s| unsafe { s.agg_first(&groups) }) }, (UniqueKeepStrategy::Last, false) => { let gb = df.group_by(names)?; let groups = gb.get_groups(); let (offset, len) = slice.unwrap_or((0, groups.len())); let groups = groups.slice(offset, len); - df.apply_columns_par(&|s| unsafe { s.agg_last(&groups) }) + df._apply_columns_par(&|s| unsafe { s.agg_last(&groups) }) }, (UniqueKeepStrategy::None, _) => { let df_part = df.select(names)?; @@ -2713,7 +2748,7 @@ impl DataFrame { return df.filter(&mask); }, }; - Ok(DataFrame::new_no_checks(columns)) + Ok(unsafe { DataFrame::new_no_checks(columns) }) } /// Get a mask of all the unique rows in the [`DataFrame`]. @@ -2774,7 +2809,7 @@ impl DataFrame { .iter() .map(|s| Series::new(s.name(), &[s.null_count() as IdxSize])) .collect(); - Self::new_no_checks(cols) + unsafe { Self::new_no_checks(cols) } } /// Hash and combine the row values @@ -2802,56 +2837,6 @@ impl DataFrame { .reduce(|acc, b| try_get_supertype(&acc?, &b.unwrap())) } - #[cfg(feature = "chunked_ids")] - #[doc(hidden)] - /// Take elements by a slice of [`ChunkId`]s. - /// # Safety - /// Does not do any bound checks. - /// `sorted` indicates if the chunks are sorted. - #[doc(hidden)] - pub unsafe fn _take_chunked_unchecked_seq(&self, idx: &[ChunkId], sorted: IsSorted) -> Self { - let cols = self.apply_columns(&|s| s._take_chunked_unchecked(idx, sorted)); - - DataFrame::new_no_checks(cols) - } - #[cfg(feature = "chunked_ids")] - /// Take elements by a slice of optional [`ChunkId`]s. - /// # Safety - /// Does not do any bound checks. - #[doc(hidden)] - pub unsafe fn _take_opt_chunked_unchecked_seq(&self, idx: &[Option]) -> Self { - let cols = self.apply_columns(&|s| match s.dtype() { - DataType::String => s._take_opt_chunked_unchecked_threaded(idx, true), - _ => s._take_opt_chunked_unchecked(idx), - }); - - DataFrame::new_no_checks(cols) - } - - #[cfg(feature = "chunked_ids")] - /// # Safety - /// Doesn't perform any bound checks - pub unsafe fn _take_chunked_unchecked(&self, idx: &[ChunkId], sorted: IsSorted) -> Self { - let cols = self.apply_columns_par(&|s| match s.dtype() { - DataType::String => s._take_chunked_unchecked_threaded(idx, sorted, true), - _ => s._take_chunked_unchecked(idx, sorted), - }); - - DataFrame::new_no_checks(cols) - } - - #[cfg(feature = "chunked_ids")] - /// # Safety - /// Doesn't perform any bound checks - pub unsafe fn _take_opt_chunked_unchecked(&self, idx: &[Option]) -> Self { - let cols = self.apply_columns_par(&|s| match s.dtype() { - DataType::String => s._take_opt_chunked_unchecked_threaded(idx, true), - _ => s._take_opt_chunked_unchecked(idx), - }); - - DataFrame::new_no_checks(cols) - } - /// Take by index values given by the slice `idx`. /// # Warning /// Be careful with allowing threads when calling this in a large hot loop @@ -3053,7 +3038,7 @@ impl Iterator for PhysRecordBatchIter<'_> { impl Default for DataFrame { fn default() -> Self { - DataFrame::new_no_checks(vec![]) + DataFrame::empty() } } @@ -3076,7 +3061,6 @@ fn ensure_can_extend(left: &Series, right: &Series) -> PolarsResult<()> { #[cfg(test)] mod test { use super::*; - use crate::frame::NullStrategy; fn create_frame() -> DataFrame { let s0 = Series::new("days", [0, 1, 2].as_ref()); diff --git a/crates/polars-core/src/frame/row/av_buffer.rs b/crates/polars-core/src/frame/row/av_buffer.rs index 5c5d737db096..4a2f7ebfe1ff 100644 --- a/crates/polars-core/src/frame/row/av_buffer.rs +++ b/crates/polars-core/src/frame/row/av_buffer.rs @@ -120,7 +120,7 @@ impl<'a> AnyValueBuffer<'a> { #[cfg(feature = "dtype-time")] (Time(builder), val) if val.is_numeric() => builder.append_value(val.extract()?), (Null(builder), AnyValue::Null) => builder.append_null(), - // Struct and List can be recursive so use anyvalues for that + // Struct and List can be recursive so use AnyValues for that (All(_, vals), v) => vals.push(v), // dynamic types @@ -299,7 +299,7 @@ impl From<(&DataType, usize)> for AnyValueBuffer<'_> { Float64 => AnyValueBuffer::Float64(PrimitiveChunkedBuilder::new("", len)), String => AnyValueBuffer::String(StringChunkedBuilder::new("", len)), Null => AnyValueBuffer::Null(NullChunkedBuilder::new("", 0)), - // Struct and List can be recursive so use anyvalues for that + // Struct and List can be recursive so use AnyValues for that dt => AnyValueBuffer::All(dt.clone(), Vec::with_capacity(len)), } } @@ -456,7 +456,7 @@ impl<'a> AnyValueBufferTrusted<'a> { /// Will add the [`AnyValue`] into [`Self`] and unpack as the physical type /// belonging to [`Self`]. This should only be used with physical buffers /// - /// If a type is not primitive or String, the anyvalue will be converted to static + /// If a type is not primitive or String, the AnyValues will be converted to static /// /// # Safety /// The caller must ensure that the [`AnyValue`] type exactly matches the `Buffer` type and is owned. @@ -668,7 +668,7 @@ impl From<(&DataType, usize)> for AnyValueBufferTrusted<'_> { .collect::>(); AnyValueBufferTrusted::Struct(buffers) }, - // List can be recursive so use anyvalues for that + // List can be recursive so use AnyValues for that dt => AnyValueBufferTrusted::All(dt.clone(), Vec::with_capacity(len)), } } diff --git a/crates/polars-core/src/frame/row/dataframe.rs b/crates/polars-core/src/frame/row/dataframe.rs index 1aa2197d1ac5..266677b14def 100644 --- a/crates/polars-core/src/frame/row/dataframe.rs +++ b/crates/polars-core/src/frame/row/dataframe.rs @@ -1,5 +1,4 @@ use super::*; -use crate::frame::row::av_buffer::AnyValueBuffer; impl DataFrame { /// Get a row from a [`DataFrame`]. Use of this is discouraged as it will likely be slow. diff --git a/crates/polars-core/src/frame/row/mod.rs b/crates/polars-core/src/frame/row/mod.rs index b0c8fcd30f69..7e899fbb2660 100644 --- a/crates/polars-core/src/frame/row/mod.rs +++ b/crates/polars-core/src/frame/row/mod.rs @@ -83,16 +83,6 @@ pub fn coerce_data_type>(datatypes: &[A]) -> DataType { try_get_supertype(lhs, rhs).unwrap_or(String) } -fn is_nested_null(av: &AnyValue) -> bool { - match av { - AnyValue::Null => true, - AnyValue::List(s) => s.null_count() == s.len(), - #[cfg(feature = "dtype-struct")] - AnyValue::Struct(_, _, _) => av._iter_struct_av().all(|av| is_nested_null(&av)), - _ => false, - } -} - pub fn any_values_to_dtype(column: &[AnyValue]) -> PolarsResult<(DataType, usize)> { // we need an index-map as the order of dtypes influences how the // struct fields are constructed. @@ -173,7 +163,7 @@ pub fn rows_to_schema_first_non_null(rows: &[Row], infer_schema_length: Option, new_col_names: Option>>, ) -> PolarsResult { + // We must iterate columns as [`AnyValue`], so we must be contiguous. + self.as_single_chunk_par(); + let mut df = Cow::Borrowed(self); // Can't use self because we might drop a name column let names_out = match new_col_names { None => (0..self.height()).map(|i| format!("column_{i}")).collect(), @@ -186,12 +189,12 @@ where let s = s.cast(&T::get_dtype()).unwrap(); let ca = s.unpack::().unwrap(); - // Safety + // SAFETY: // we access in parallel, but every access is unique, so we don't break aliasing rules // we also ensured we allocated enough memory, so we never reallocate and thus // the pointers remain valid. if has_nulls { - for (col_idx, opt_v) in ca.into_iter().enumerate() { + for (col_idx, opt_v) in ca.iter().enumerate() { match opt_v { None => unsafe { let column = (*(validity_buf_ptr as *mut Vec>)) @@ -221,37 +224,36 @@ where }) }); - cols_t.par_extend(POOL.install(|| { - values_buf - .into_par_iter() - .zip(validity_buf) - .zip(names_out) - .map(|((mut values, validity), name)| { - // Safety: - // all values are written we can now set len - unsafe { - values.set_len(new_height); - } + let par_iter = values_buf + .into_par_iter() + .zip(validity_buf) + .zip(names_out) + .map(|((mut values, validity), name)| { + // SAFETY: + // all values are written we can now set len + unsafe { + values.set_len(new_height); + } - let validity = if has_nulls { - let validity = Bitmap::from_trusted_len_iter(validity.iter().copied()); - if validity.unset_bits() > 0 { - Some(validity) - } else { - None - } + let validity = if has_nulls { + let validity = Bitmap::from_trusted_len_iter(validity.iter().copied()); + if validity.unset_bits() > 0 { + Some(validity) } else { None - }; + } + } else { + None + }; - let arr = PrimitiveArray::::new( - T::get_dtype().to_arrow(true), - values.into(), - validity, - ); - ChunkedArray::with_chunk(name.as_str(), arr).into_series() - }) - })); + let arr = PrimitiveArray::::new( + T::get_dtype().to_arrow(true), + values.into(), + validity, + ); + ChunkedArray::with_chunk(name.as_str(), arr).into_series() + }); + POOL.install(|| cols_t.par_extend(par_iter)); } #[cfg(test)] @@ -260,7 +262,7 @@ mod test { #[test] fn test_transpose() -> PolarsResult<()> { - let df = df![ + let mut df = df![ "a" => [1, 2, 3], "b" => [10, 20, 30], ]?; @@ -274,7 +276,7 @@ mod test { ]?; assert!(out.equals_missing(&expected)); - let df = df![ + let mut df = df![ "a" => [Some(1), None, Some(3)], "b" => [Some(10), Some(20), None], ]?; @@ -287,7 +289,7 @@ mod test { ]?; assert!(out.equals_missing(&expected)); - let df = df![ + let mut df = df![ "a" => ["a", "b", "c"], "b" => [Some(10), Some(20), None], ]?; diff --git a/crates/polars-core/src/frame/top_k.rs b/crates/polars-core/src/frame/top_k.rs index b72116821dc9..efd76d0dda69 100644 --- a/crates/polars-core/src/frame/top_k.rs +++ b/crates/polars-core/src/frame/top_k.rs @@ -1,12 +1,9 @@ use std::cmp::Ordering; -use polars_error::PolarsResult; use polars_utils::iter::EnumerateIdxTrait; use polars_utils::IdxSize; use smartstring::alias::String as SmartString; -use crate::datatypes::IdxCa; -use crate::frame::DataFrame; use crate::prelude::sort::_broadcast_descending; use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded; use crate::prelude::*; diff --git a/crates/polars-core/src/frame/upstream_traits.rs b/crates/polars-core/src/frame/upstream_traits.rs index 21f5a0e74f84..e2f28aefdb33 100644 --- a/crates/polars-core/src/frame/upstream_traits.rs +++ b/crates/polars-core/src/frame/upstream_traits.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::ops::{Index, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; use crate::prelude::*; diff --git a/crates/polars-core/src/functions.rs b/crates/polars-core/src/functions.rs index 6c802e02656c..06fcea661219 100644 --- a/crates/polars-core/src/functions.rs +++ b/crates/polars-core/src/functions.rs @@ -75,7 +75,7 @@ pub fn concat_df_diagonal(dfs: &[DataFrame]) -> PolarsResult { None => columns.push(Series::full_null(name, height, dtype)), } } - DataFrame::new_no_checks(columns) + unsafe { DataFrame::new_no_checks(columns) } }) .collect::>(); diff --git a/crates/polars-core/src/hashing/mod.rs b/crates/polars-core/src/hashing/mod.rs index 6abb9c7097cc..5e1b891a9702 100644 --- a/crates/polars-core/src/hashing/mod.rs +++ b/crates/polars-core/src/hashing/mod.rs @@ -74,7 +74,7 @@ pub fn populate_multiple_key_hashmap( // before we incur a cache miss idx_hash.hash == original_h && { let key_idx = idx_hash.idx; - // Safety: + // SAFETY: // indices in a group_by operation are always in bounds. unsafe { compare_df_rows(keys, key_idx as usize, idx as usize) } } diff --git a/crates/polars-core/src/hashing/vector_hasher.rs b/crates/polars-core/src/hashing/vector_hasher.rs index db4a8fab1005..7f2edbfa02c4 100644 --- a/crates/polars-core/src/hashing/vector_hasher.rs +++ b/crates/polars-core/src/hashing/vector_hasher.rs @@ -1,13 +1,13 @@ use arrow::bitmap::utils::get_bit_unchecked; #[cfg(feature = "group_by_list")] use arrow::legacy::kernels::list_bytes_iter::numeric_list_bytes_iter; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use rayon::prelude::*; use xxhash_rust::xxh3::xxh3_64_with_seed; use super::*; -use crate::datatypes::UInt64Chunked; use crate::prelude::*; -use crate::utils::arrow::array::Array; +use crate::series::implementations::null::NullChunked; use crate::POOL; // See: https://github.com/tkaitchuck/aHash/blob/f9acd508bd89e7c5b2877a9510098100f9018d64/src/operations.rs#L4 @@ -66,10 +66,11 @@ fn insert_null_hash(chunks: &[ArrayRef], random_state: RandomState, buf: &mut Ve }); } -fn integer_vec_hash(ca: &ChunkedArray, random_state: RandomState, buf: &mut Vec) +fn numeric_vec_hash(ca: &ChunkedArray, random_state: RandomState, buf: &mut Vec) where - T: PolarsIntegerType, - T::Native: Hash, + T: PolarsNumericType, + T::Native: TotalHash + ToTotalOrd, + ::TotalOrdItem: Hash, { // Note that we don't use the no null branch! This can break in unexpected ways. // for instance with threading we split an array in n_threads, this may lead to @@ -88,16 +89,17 @@ where .as_slice() .iter() .copied() - .map(|v| random_state.hash_one(v)), + .map(|v| random_state.hash_one(v.to_total_ord())), ); }); insert_null_hash(&ca.chunks, random_state, buf) } -fn integer_vec_hash_combine(ca: &ChunkedArray, random_state: RandomState, hashes: &mut [u64]) +fn numeric_vec_hash_combine(ca: &ChunkedArray, random_state: RandomState, hashes: &mut [u64]) where - T: PolarsIntegerType, - T::Native: Hash, + T: PolarsNumericType, + T::Native: TotalHash + ToTotalOrd, + ::TotalOrdItem: Hash, { let null_h = get_null_hash_value(&random_state); @@ -110,7 +112,7 @@ where .iter() .zip(&mut hashes[offset..]) .for_each(|(v, h)| { - *h = folded_multiply(random_state.hash_one(v) ^ *h, MULTIPLE); + *h = folded_multiply(random_state.hash_one(v.to_total_ord()) ^ *h, MULTIPLE); }), _ => { let validity = arr.validity().unwrap(); @@ -120,7 +122,7 @@ where .zip(&mut hashes[offset..]) .zip(arr.values().as_slice()) .for_each(|((valid, h), l)| { - let lh = random_state.hash_one(l); + let lh = random_state.hash_one(l.to_total_ord()); let to_hash = [null_h, lh][valid as usize]; // inlined from ahash. This ensures we combine with the previous state @@ -132,11 +134,11 @@ where }); } -macro_rules! vec_hash_int { +macro_rules! vec_hash_numeric { ($ca:ident) => { impl VecHash for $ca { fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { - integer_vec_hash(self, random_state, buf); + numeric_vec_hash(self, random_state, buf); Ok(()) } @@ -145,21 +147,23 @@ macro_rules! vec_hash_int { random_state: RandomState, hashes: &mut [u64], ) -> PolarsResult<()> { - integer_vec_hash_combine(self, random_state, hashes); + numeric_vec_hash_combine(self, random_state, hashes); Ok(()) } } }; } -vec_hash_int!(Int64Chunked); -vec_hash_int!(Int32Chunked); -vec_hash_int!(Int16Chunked); -vec_hash_int!(Int8Chunked); -vec_hash_int!(UInt64Chunked); -vec_hash_int!(UInt32Chunked); -vec_hash_int!(UInt16Chunked); -vec_hash_int!(UInt8Chunked); +vec_hash_numeric!(Int64Chunked); +vec_hash_numeric!(Int32Chunked); +vec_hash_numeric!(Int16Chunked); +vec_hash_numeric!(Int8Chunked); +vec_hash_numeric!(UInt64Chunked); +vec_hash_numeric!(UInt32Chunked); +vec_hash_numeric!(UInt16Chunked); +vec_hash_numeric!(UInt8Chunked); +vec_hash_numeric!(Float64Chunked); +vec_hash_numeric!(Float32Chunked); impl VecHash for StringChunked { fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { @@ -290,6 +294,22 @@ impl VecHash for BinaryOffsetChunked { } } +impl VecHash for NullChunked { + fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + let null_h = get_null_hash_value(&random_state); + buf.clear(); + buf.resize(self.len(), null_h); + Ok(()) + } + + fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + let null_h = get_null_hash_value(&random_state); + hashes + .iter_mut() + .for_each(|h| *h = _boost_hash_combine(null_h, *h)); + Ok(()) + } +} impl VecHash for BooleanChunked { fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { buf.clear(); @@ -353,30 +373,6 @@ impl VecHash for BooleanChunked { } } -impl VecHash for Float32Chunked { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { - self.bit_repr_small().vec_hash(random_state, buf)?; - Ok(()) - } - - fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { - self.bit_repr_small() - .vec_hash_combine(random_state, hashes)?; - Ok(()) - } -} -impl VecHash for Float64Chunked { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { - self.bit_repr_large().vec_hash(random_state, buf)?; - Ok(()) - } - fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { - self.bit_repr_large() - .vec_hash_combine(random_state, hashes)?; - Ok(()) - } -} - #[cfg(feature = "group_by_list")] impl VecHash for ListChunked { fn vec_hash(&self, _random_state: RandomState, _buf: &mut Vec) -> PolarsResult<()> { diff --git a/crates/polars-core/src/prelude.rs b/crates/polars-core/src/prelude.rs index 96677102c23c..a4bfe06ca307 100644 --- a/crates/polars-core/src/prelude.rs +++ b/crates/polars-core/src/prelude.rs @@ -8,8 +8,11 @@ pub use arrow::datatypes::{ArrowSchema, Field as ArrowField}; pub use arrow::legacy::kernels::ewm::EWMOptions; pub use arrow::legacy::prelude::*; pub(crate) use arrow::trusted_len::TrustedLen; +#[cfg(feature = "chunked_ids")] +pub(crate) use polars_utils::index::ChunkId; pub(crate) use polars_utils::total_ord::{TotalEq, TotalOrd}; +pub use crate::chunked_array::arithmetic::ArithmeticChunked; pub use crate::chunked_array::builder::{ BinaryChunkedBuilder, BooleanChunkedBuilder, ChunkedBuilder, ListBinaryChunkedBuilder, ListBooleanChunkedBuilder, ListBuilderTrait, ListPrimitiveChunkedBuilder, diff --git a/crates/polars-core/src/random.rs b/crates/polars-core/src/random.rs index 185144063197..f6fd3c1f3978 100644 --- a/crates/polars-core/src/random.rs +++ b/crates/polars-core/src/random.rs @@ -2,7 +2,6 @@ use std::sync::Mutex; use once_cell::sync::Lazy; use rand::prelude::*; -use rand::rngs::SmallRng; static POLARS_GLOBAL_RNG_STATE: Lazy> = Lazy::new(|| Mutex::new(SmallRng::from_entropy())); diff --git a/crates/polars-core/src/serde/chunked_array.rs b/crates/polars-core/src/serde/chunked_array.rs index b39166e9dfed..4da7f1b4f3dc 100644 --- a/crates/polars-core/src/serde/chunked_array.rs +++ b/crates/polars-core/src/serde/chunked_array.rs @@ -59,7 +59,7 @@ where state.serialize_entry("name", name)?; state.serialize_entry("datatype", dtype)?; state.serialize_entry("bit_settings", &bit_settings)?; - state.serialize_entry("values", &IterSer::new(ca.into_iter()))?; + state.serialize_entry("values", &IterSer::new(ca.iter()))?; state.end() } diff --git a/crates/polars-core/src/serde/series.rs b/crates/polars-core/src/serde/series.rs index cbd2519dbe52..a40134575c5f 100644 --- a/crates/polars-core/src/serde/series.rs +++ b/crates/polars-core/src/serde/series.rs @@ -234,12 +234,12 @@ impl<'de> Deserialize<'de> for Series { if let Some(s) = value { // we only have one chunk per series as we serialize it in this way. let arr = &s.chunks()[0]; - // safety, we are within bounds + // SAFETY, we are within bounds unsafe { builder.push_unchecked(arr.as_ref(), 0); } } else { - // safety, we are within bounds + // SAFETY, we are within bounds unsafe { builder.push_null(); } diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs index d8d27c7bae22..d34bad6512df 100644 --- a/crates/polars-core/src/series/any_value.rs +++ b/crates/polars-core/src/series/any_value.rs @@ -1,6 +1,7 @@ use std::fmt::Write; use crate::prelude::*; +use crate::utils::get_supertype; fn any_values_to_primitive(avs: &[AnyValue]) -> ChunkedArray { avs.iter() @@ -8,36 +9,6 @@ fn any_values_to_primitive(avs: &[AnyValue]) -> ChunkedArr .collect_trusted() } -fn any_values_to_string(avs: &[AnyValue], strict: bool) -> PolarsResult { - let mut builder = StringChunkedBuilder::new("", avs.len()); - - // amortize allocations - let mut owned = String::new(); - - for av in avs { - match av { - AnyValue::String(s) => builder.append_value(s), - AnyValue::StringOwned(s) => builder.append_value(s), - AnyValue::Null => builder.append_null(), - AnyValue::Binary(_) | AnyValue::BinaryOwned(_) => { - if strict { - polars_bail!(ComputeError: "mixed dtypes found when building String Series") - } - builder.append_null() - }, - av => { - if strict { - polars_bail!(ComputeError: "mixed dtypes found when building String Series") - } - owned.clear(); - write!(owned, "{av}").unwrap(); - builder.append_value(&owned); - }, - } - } - Ok(builder.finish()) -} - #[cfg(feature = "dtype-decimal")] fn any_values_to_decimal( avs: &[AnyValue], @@ -105,25 +76,6 @@ fn any_values_to_decimal( builder.finish().into_decimal(precision, scale) } -fn any_values_to_binary(avs: &[AnyValue]) -> BinaryChunked { - avs.iter() - .map(|av| match av { - AnyValue::Binary(s) => Some(*s), - AnyValue::BinaryOwned(s) => Some(&**s), - _ => None, - }) - .collect_trusted() -} - -fn any_values_to_bool(avs: &[AnyValue]) -> BooleanChunked { - avs.iter() - .map(|av| match av { - AnyValue::Boolean(b) => Some(*b), - _ => None, - }) - .collect_trusted() -} - #[cfg(feature = "dtype-array")] fn any_values_to_array( avs: &[AnyValue], @@ -155,7 +107,7 @@ fn any_values_to_array( }) .collect_ca_with_dtype("", DataType::Array(Box::new(inner_type.clone()), width)) } - // make sure that wrongly inferred anyvalues don't deviate from the datatype + // make sure that wrongly inferred AnyValues don't deviate from the datatype else { avs.iter() .map(|av| match av { @@ -218,7 +170,7 @@ fn any_values_to_list( }) .collect_trusted() } - // make sure that wrongly inferred anyvalues don't deviate from the datatype + // make sure that wrongly inferred AnyValues don't deviate from the datatype else { avs.iter() .map(|av| match av { @@ -264,6 +216,7 @@ impl<'a, T: AsRef<[AnyValue<'a>]>> NamedFrom]> for Series { } impl Series { + /// Construct a new [`Series`]` with the given `dtype` from a slice of AnyValues. pub fn from_any_values_and_dtype( name: &str, av: &[AnyValue], @@ -286,8 +239,8 @@ impl Series { DataType::Float32 => any_values_to_primitive::(av).into_series(), DataType::Float64 => any_values_to_primitive::(av).into_series(), DataType::String => any_values_to_string(av, strict)?.into_series(), - DataType::Binary => any_values_to_binary(av).into_series(), - DataType::Boolean => any_values_to_bool(av).into_series(), + DataType::Binary => any_values_to_binary(av, strict)?.into_series(), + DataType::Boolean => any_values_to_bool(av, strict)?.into_series(), #[cfg(feature = "dtype-date")] DataType::Date => any_values_to_primitive::(av) .into_date() @@ -423,7 +376,7 @@ impl Series { }, } }, - DataType::Null => Series::full_null(name, av.len(), &DataType::Null), + DataType::Null => Series::new_null(name, av.len()), #[cfg(feature = "dtype-categorical")] dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => { let ca = if let Some(single_av) = av.first() { @@ -449,103 +402,170 @@ impl Series { Ok(s) } - pub fn from_any_values(name: &str, avs: &[AnyValue], strict: bool) -> PolarsResult { - let mut all_flat_null = true; - match avs.iter().find(|av| { - if !matches!(av, AnyValue::Null) { - all_flat_null = false; - } - !av.is_nested_null() - }) { - None => { - if all_flat_null { - Ok(Series::full_null(name, avs.len(), &DataType::Null)) - } else { - // second pass and check for the nested null value that toggled `all_flat_null` to false - // e.g. a list - if let Some(av) = avs.iter().find(|av| !matches!(av, AnyValue::Null)) { - let dtype: DataType = av.into(); - Series::from_any_values_and_dtype(name, avs, &dtype, strict) + /// Construct a new [`Series`] from a slice of AnyValues. + /// + /// The data type of the resulting Series is determined by the `values` + /// and the `strict` parameter: + /// - If `strict` is `true`, the data type is equal to the data type of the + /// first non-null value. If any other non-null values do not match this + /// data type, an error is raised. + /// - If `strict` is `false`, the data type is the supertype of the + /// `values`. **WARNING**: A full pass over the values is required to + /// determine the supertype. Values encountered that do not match the + /// supertype are set to null. + /// - If no values were passed, the resulting data type is `Null`. + pub fn from_any_values(name: &str, values: &[AnyValue], strict: bool) -> PolarsResult { + fn get_first_non_null_dtype(values: &[AnyValue]) -> DataType { + let mut all_flat_null = true; + let first_non_null = values.iter().find(|av| { + if !av.is_null() { + all_flat_null = false + }; + !av.is_nested_null() + }); + match first_non_null { + Some(av) => av.dtype(), + None => { + if all_flat_null { + DataType::Null } else { - unreachable!() + // Second pass to check for the nested null value that + // toggled `all_flat_null` to false, e.g. a List(Null) + let first_nested_null = values.iter().find(|av| !av.is_null()).unwrap(); + first_nested_null.dtype() } - } - }, - Some(av) => { - #[cfg(feature = "dtype-decimal")] - { - if let AnyValue::Decimal(_, _) = av { - let mut s = any_values_to_decimal(avs, None, None)?.into_series(); - s.rename(name); - return Ok(s); + }, + } + } + fn get_any_values_supertype(values: &[AnyValue]) -> DataType { + let mut supertype = DataType::Null; + let mut dtypes = PlHashSet::::new(); + for av in values { + if dtypes.insert(av.dtype()) { + // Values with incompatible data types will be set to null later + if let Some(st) = get_supertype(&supertype, &av.dtype()) { + supertype = st; } } - let dtype: DataType = av.into(); - Series::from_any_values_and_dtype(name, avs, &dtype, strict) - }, + } + supertype } + + let dtype = if strict { + get_first_non_null_dtype(values) + } else { + get_any_values_supertype(values) + }; + Self::from_any_values_and_dtype(name, values, &dtype, strict) } } -impl<'a> From<&AnyValue<'a>> for DataType { - fn from(val: &AnyValue<'a>) -> Self { - use AnyValue::*; - match val { - Null => DataType::Null, - Boolean(_) => DataType::Boolean, - String(_) | StringOwned(_) => DataType::String, - Binary(_) | BinaryOwned(_) => DataType::Binary, - UInt32(_) => DataType::UInt32, - UInt64(_) => DataType::UInt64, - Int32(_) => DataType::Int32, - Int64(_) => DataType::Int64, - Float32(_) => DataType::Float32, - Float64(_) => DataType::Float64, - #[cfg(feature = "dtype-date")] - Date(_) => DataType::Date, - #[cfg(feature = "dtype-datetime")] - Datetime(_, tu, tz) => DataType::Datetime(*tu, (*tz).clone()), - #[cfg(feature = "dtype-time")] - Time(_) => DataType::Time, - #[cfg(feature = "dtype-array")] - Array(s, size) => DataType::Array(Box::new(s.dtype().clone()), *size), - List(s) => DataType::List(Box::new(s.dtype().clone())), - #[cfg(feature = "dtype-struct")] - StructOwned(payload) => DataType::Struct(payload.1.to_vec()), - #[cfg(feature = "dtype-struct")] - Struct(_, _, flds) => DataType::Struct(flds.to_vec()), - #[cfg(feature = "dtype-duration")] - Duration(_, tu) => DataType::Duration(*tu), - UInt8(_) => DataType::UInt8, - UInt16(_) => DataType::UInt16, - Int8(_) => DataType::Int8, - Int16(_) => DataType::Int16, - #[cfg(feature = "dtype-categorical")] - Categorical(_, rev_map, arr) => { - if arr.is_null() { - DataType::Categorical(Some(Arc::new((*rev_map).clone())), Default::default()) - } else { - let array = unsafe { arr.deref_unchecked().clone() }; - let rev_map = RevMapping::build_local(array); - DataType::Categorical(Some(Arc::new(rev_map)), Default::default()) - } - }, - #[cfg(feature = "dtype-categorical")] - Enum(_, rev_map, arr) => { - if arr.is_null() { - DataType::Enum(Some(Arc::new((*rev_map).clone())), Default::default()) - } else { - let array = unsafe { arr.deref_unchecked().clone() }; - let rev_map = RevMapping::build_local(array); - DataType::Enum(Some(Arc::new(rev_map)), Default::default()) - } +fn any_values_to_bool(values: &[AnyValue], strict: bool) -> PolarsResult { + if strict { + any_values_to_bool_strict(values) + } else { + Ok(any_values_to_bool_nonstrict(values)) + } +} +fn any_values_to_bool_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = BooleanChunkedBuilder::new("", values.len()); + for av in values { + match av { + AnyValue::Boolean(b) => builder.append_value(*b), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::Boolean, av)), + } + } + Ok(builder.finish()) +} +fn any_values_to_bool_nonstrict(values: &[AnyValue]) -> BooleanChunked { + let mapper = |av: &AnyValue| match av { + AnyValue::Boolean(b) => Some(*b), + AnyValue::Null => None, + av => match av.cast(&DataType::Boolean) { + AnyValue::Boolean(b) => Some(b), + _ => None, + }, + }; + values.iter().map(mapper).collect_trusted() +} + +fn any_values_to_string(values: &[AnyValue], strict: bool) -> PolarsResult { + if strict { + any_values_to_string_strict(values) + } else { + Ok(any_values_to_string_nonstrict(values)) + } +} +fn any_values_to_string_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = StringChunkedBuilder::new("", values.len()); + for av in values { + match av { + AnyValue::String(s) => builder.append_value(s), + AnyValue::StringOwned(s) => builder.append_value(s), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::String, av)), + } + } + Ok(builder.finish()) +} +fn any_values_to_string_nonstrict(values: &[AnyValue]) -> StringChunked { + let mut builder = StringChunkedBuilder::new("", values.len()); + let mut owned = String::new(); // Amortize allocations + for av in values { + match av { + AnyValue::String(s) => builder.append_value(s), + AnyValue::StringOwned(s) => builder.append_value(s), + AnyValue::Null => builder.append_null(), + AnyValue::Binary(_) | AnyValue::BinaryOwned(_) => builder.append_null(), + av => { + owned.clear(); + write!(owned, "{av}").unwrap(); + builder.append_value(&owned); }, - #[cfg(feature = "object")] - Object(o) => DataType::Object(o.type_name(), None), - #[cfg(feature = "object")] - ObjectOwned(o) => DataType::Object(o.0.type_name(), None), - #[cfg(feature = "dtype-decimal")] - Decimal(_, scale) => DataType::Decimal(None, Some(*scale)), } } + builder.finish() +} + +fn any_values_to_binary(values: &[AnyValue], strict: bool) -> PolarsResult { + if strict { + any_values_to_binary_strict(values) + } else { + Ok(any_values_to_binary_nonstrict(values)) + } +} +fn any_values_to_binary_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = BinaryChunkedBuilder::new("", values.len()); + for av in values { + match av { + AnyValue::Binary(s) => builder.append_value(*s), + AnyValue::BinaryOwned(s) => builder.append_value(&**s), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::Binary, av)), + } + } + Ok(builder.finish()) +} +fn any_values_to_binary_nonstrict(values: &[AnyValue]) -> BinaryChunked { + values + .iter() + .map(|av| match av { + AnyValue::Binary(b) => Some(*b), + AnyValue::BinaryOwned(b) => Some(&**b), + AnyValue::String(s) => Some(s.as_bytes()), + AnyValue::StringOwned(s) => Some(s.as_bytes()), + _ => None, + }) + .collect_trusted() +} + +fn invalid_value_error(dtype: &DataType, value: &AnyValue) -> PolarsError { + polars_err!( + SchemaMismatch: + "unexpected value while building Series of type {:?}; found value of type {:?}: {}", + dtype, + value.dtype(), + value + ) } diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index ca3b75d02380..22101912b285 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -50,7 +50,7 @@ where ChunkedArray: IntoSeries, { fn subtract(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { - // Safety: + // SAFETY: // There will be UB if a ChunkedArray is alive with the wrong datatype. // we now only create the potentially wrong dtype for a short time. // Note that the physical type correctness is checked! @@ -60,28 +60,28 @@ where Ok(out.into_series()) } fn add_to(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { - // Safety: + // SAFETY: // see subtract let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; let out = lhs + rhs; Ok(out.into_series()) } fn multiply(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { - // Safety: + // SAFETY: // see subtract let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; let out = lhs * rhs; Ok(out.into_series()) } fn divide(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { - // Safety: + // SAFETY: // see subtract let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; let out = lhs / rhs; Ok(out.into_series()) } fn remainder(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { - // Safety: + // SAFETY: // see subtract let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; let out = lhs % rhs; @@ -154,7 +154,7 @@ pub mod checked { ChunkedArray: IntoSeries, { fn checked_div(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { - // Safety: + // SAFETY: // There will be UB if a ChunkedArray is alive with the wrong datatype. // we now only create the potentially wrong dtype for a short time. // Note that the physical type correctness is checked! @@ -173,7 +173,7 @@ pub mod checked { impl NumOpsDispatchCheckedInner for Float32Type { fn checked_div(lhs: &Float32Chunked, rhs: &Series) -> PolarsResult { - // Safety: + // SAFETY: // see check_div for chunkedarray let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; @@ -194,7 +194,7 @@ pub mod checked { impl NumOpsDispatchCheckedInner for Float64Type { fn checked_div(lhs: &Float64Chunked, rhs: &Series) -> PolarsResult { - // Safety: + // SAFETY: // see check_div let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; @@ -394,16 +394,16 @@ pub fn _struct_arithmetic Series>( match (s_fields.len(), rhs_fields.len()) { (_, 1) => { let rhs = &rhs.fields()[0]; - s.apply_fields(|s| func(s, rhs)).into_series() + s._apply_fields(|s| func(s, rhs)).into_series() }, (1, _) => { let s = &s.fields()[0]; - rhs.apply_fields(|rhs| func(s, rhs)).into_series() + rhs._apply_fields(|rhs| func(s, rhs)).into_series() }, _ => { let mut rhs_iter = rhs.fields().iter(); - s.apply_fields(|s| match rhs_iter.next() { + s._apply_fields(|s| match rhs_iter.next() { Some(rhs) => func(s, rhs), None => s.clone(), }) @@ -622,6 +622,22 @@ where } } +// TODO: remove this, temporary band-aid. +impl Series { + pub fn wrapping_trunc_div_scalar(&self, rhs: T) -> Self { + let s = self.to_physical_repr(); + macro_rules! div { + ($ca:expr) => {{ + let rhs = NumCast::from(rhs).unwrap(); + $ca.wrapping_trunc_div_scalar(rhs).into_series() + }}; + } + + let out = downcast_as_macro_arg_physical!(s, div); + finish_cast(self, out) + } +} + impl Mul for &Series where T: Num + NumCast, @@ -692,21 +708,21 @@ where #[must_use] pub fn lhs_sub(&self, lhs: N) -> Self { let lhs: T::Native = NumCast::from(lhs).expect("could not cast"); - self.apply_values(|v| lhs - v) + ArithmeticChunked::wrapping_sub_scalar_lhs(lhs, self) } /// Apply lhs / self #[must_use] pub fn lhs_div(&self, lhs: N) -> Self { let lhs: T::Native = NumCast::from(lhs).expect("could not cast"); - self.apply_values(|v| lhs / v) + ArithmeticChunked::legacy_div_scalar_lhs(lhs, self) } /// Apply lhs % self #[must_use] pub fn lhs_rem(&self, lhs: N) -> Self { let lhs: T::Native = NumCast::from(lhs).expect("could not cast"); - self.apply_values(|v| lhs % v) + ArithmeticChunked::wrapping_mod_scalar_lhs(lhs, self) } } diff --git a/crates/polars-core/src/series/comparison.rs b/crates/polars-core/src/series/comparison.rs index 837d9f32436f..15c891aef935 100644 --- a/crates/polars-core/src/series/comparison.rs +++ b/crates/polars-core/src/series/comparison.rs @@ -3,8 +3,6 @@ #[cfg(feature = "dtype-struct")] use std::ops::Deref; -use super::Series; -use crate::apply_method_physical_numeric; use crate::prelude::*; use crate::series::arithmetic::coerce_lhs_rhs; use crate::series::nulls::replace_non_null; @@ -45,6 +43,7 @@ macro_rules! impl_compare { let lhs = lhs.to_physical_repr(); let rhs = rhs.to_physical_repr(); let mut out = match lhs.dtype() { + Null => lhs.null().unwrap().$method(rhs.null().unwrap()), Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()), String => lhs.str().unwrap().$method(rhs.str().unwrap()), Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()), @@ -66,6 +65,16 @@ macro_rules! impl_compare { .struct_() .unwrap() .$method(rhs.struct_().unwrap().deref()), + #[cfg(feature = "dtype-decimal")] + Decimal(_, s1) => { + let DataType::Decimal(_, s2) = rhs.dtype() else { + unreachable!() + }; + let scale = s1.max(s2).unwrap(); + let lhs = lhs.decimal().unwrap().to_scale(scale).unwrap(); + let rhs = rhs.decimal().unwrap().to_scale(scale).unwrap(); + lhs.0.$method(&rhs.0) + }, _ => unimplemented!(), }; @@ -97,42 +106,22 @@ impl ChunkCompare<&Series> for Series { /// Create a boolean mask by checking for equality. fn equal(&self, rhs: &Series) -> PolarsResult { - match (self.dtype(), rhs.dtype()) { - (DataType::Null, DataType::Null) => { - Ok(BooleanChunked::full_null(self.name(), self.len())) - }, - _ => impl_compare!(self, rhs, equal), - } + impl_compare!(self, rhs, equal) } /// Create a boolean mask by checking for equality. fn equal_missing(&self, rhs: &Series) -> PolarsResult { - match (self.dtype(), rhs.dtype()) { - (DataType::Null, DataType::Null) => { - Ok(BooleanChunked::full(self.name(), true, self.len())) - }, - _ => impl_compare!(self, rhs, equal_missing), - } + impl_compare!(self, rhs, equal_missing) } /// Create a boolean mask by checking for inequality. fn not_equal(&self, rhs: &Series) -> PolarsResult { - match (self.dtype(), rhs.dtype()) { - (DataType::Null, DataType::Null) => { - Ok(BooleanChunked::full_null(self.name(), self.len())) - }, - _ => impl_compare!(self, rhs, not_equal), - } + impl_compare!(self, rhs, not_equal) } /// Create a boolean mask by checking for inequality. fn not_equal_missing(&self, rhs: &Series) -> PolarsResult { - match (self.dtype(), rhs.dtype()) { - (DataType::Null, DataType::Null) => { - Ok(BooleanChunked::full(self.name(), false, self.len())) - }, - _ => impl_compare!(self, rhs, not_equal_missing), - } + impl_compare!(self, rhs, not_equal_missing) } /// Create a boolean mask by checking if self > rhs. diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index e3ab7e173dda..e8a729ab29d6 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -1,5 +1,3 @@ -use std::convert::TryFrom; - use arrow::compute::cast::cast_unchecked as cast; use arrow::datatypes::Metadata; #[cfg(any(feature = "dtype-struct", feature = "dtype-categorical"))] @@ -113,7 +111,7 @@ impl Series { .as_any() .downcast_ref::() .unwrap(); - // Safety: + // SAFETY: // this is highly unsafe. it will dereference a raw ptr on the heap // make sure the ptr is allocated and from this pid // (the pid is checked before dereference) @@ -343,7 +341,7 @@ impl Series { if let Some(metadata) = md { if metadata.get(DTYPE_ENUM_KEY) == Some(&DTYPE_ENUM_VALUE.into()) { - // Safety + // SAFETY: // the invariants of an Arrow Dictionary guarantee the keys are in bounds return Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( UInt32Chunked::with_chunk(name, keys.clone()), @@ -354,7 +352,7 @@ impl Series { .into_series()); } } - // Safety + // SAFETY: // the invariants of an Arrow Dictionary guarantee the keys are in bounds Ok( CategoricalChunked::from_keys_and_values( @@ -373,7 +371,7 @@ impl Series { .as_any() .downcast_ref::() .unwrap(); - // Safety: + // SAFETY: // this is highly unsafe. it will dereference a raw ptr on the heap // make sure the ptr is allocated and from this pid // (the pid is checked before dereference) @@ -706,7 +704,7 @@ impl TryFrom<(&str, Vec)> for Series { let (name, chunks) = name_arr; let data_type = check_types(&chunks)?; - // Safety: + // SAFETY: // dtype is checked unsafe { Series::_try_from_arrow_unchecked(name, chunks, &data_type) } } @@ -729,7 +727,7 @@ impl TryFrom<(&ArrowField, Vec)> for Series { let data_type = check_types(&chunks)?; - // Safety: + // SAFETY: // dtype is checked unsafe { Series::_try_from_arrow_unchecked_with_md( diff --git a/crates/polars-core/src/series/implementations/array.rs b/crates/polars-core/src/series/implementations/array.rs index 0fc84e9e49bd..164eeceb8ba7 100644 --- a/crates/polars-core/src/series/implementations/array.rs +++ b/crates/polars-core/src/series/implementations/array.rs @@ -1,11 +1,9 @@ use std::any::Any; use std::borrow::Cow; -use super::{private, IntoSeries, SeriesTrait}; +use super::private; use crate::chunked_array::comparison::*; use crate::chunked_array::ops::explode::ExplodeByOffsets; -#[cfg(feature = "chunked_ids")] -use crate::chunked_array::ops::take::TakeChunked; use crate::chunked_array::{AsSinglePtr, Settings}; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; @@ -98,16 +96,6 @@ impl SeriesTrait for SeriesWrap { ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - self.0.take_chunked_unchecked(by, sorted).into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.0.take_opt_chunked_unchecked(by).into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self.0.take(indices)?.into_series()) } diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index ab50db3a86ad..86705e8f9af3 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -1,18 +1,8 @@ -use std::borrow::Cow; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{ - IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, -}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { @@ -142,16 +132,6 @@ impl SeriesTrait for SeriesWrap { ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - self.0.take_chunked_unchecked(by, sorted).into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.0.take_opt_chunked_unchecked(by).into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self.0.take(indices)?.into_series()) } @@ -253,4 +233,7 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } + fn as_any(&self) -> &dyn Any { + &self.0 + } } diff --git a/crates/polars-core/src/series/implementations/binary_offset.rs b/crates/polars-core/src/series/implementations/binary_offset.rs index 62d7205cc686..d0a5523c7d8c 100644 --- a/crates/polars-core/src/series/implementations/binary_offset.rs +++ b/crates/polars-core/src/series/implementations/binary_offset.rs @@ -1,16 +1,8 @@ -use std::borrow::Cow; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{ - IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, -}; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { @@ -103,16 +95,6 @@ impl SeriesTrait for SeriesWrap { ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - self.0.take_chunked_unchecked(by, sorted).into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.0.take_opt_chunked_unchecked(by).into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self.0.take(indices)?.into_series()) } @@ -189,4 +171,7 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } + fn as_any(&self) -> &dyn Any { + &self.0 + } } diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index f5d264f06a5c..1aa17d298d3f 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -1,19 +1,8 @@ -use std::borrow::Cow; -use std::ops::{BitAnd, BitOr, BitXor}; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{ - IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, -}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::{AsSinglePtr, ChunkIdIter}; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { @@ -168,16 +157,6 @@ impl SeriesTrait for SeriesWrap { self.0.mean() } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - self.0.take_chunked_unchecked(by, sorted).into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.0.take_opt_chunked_unchecked(by).into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self.0.take(indices)?.into_series()) } @@ -320,4 +299,7 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } + fn as_any(&self) -> &dyn Any { + &self.0 + } } diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index 52e85e315b56..f9a23c261417 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -1,14 +1,6 @@ -use std::borrow::Cow; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{IntoTotalOrdInner, TotalOrdInner}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; unsafe impl IntoSeries for CategoricalChunked { fn into_series(self) -> Series { @@ -26,7 +18,7 @@ impl SeriesWrap { self.0.get_ordering(), ) }; - if keep_fast_unique && self.0.can_fast_unique() { + if keep_fast_unique && self.0._can_fast_unique() { out.set_fast_unique(true) } out @@ -188,18 +180,6 @@ impl SeriesTrait for SeriesWrap { .map(|ca| ca.into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - let cats = self.0.physical().take_chunked_unchecked(by, sorted); - self.finish_with_state(false, cats).into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - let cats = self.0.physical().take_opt_chunked_unchecked(by); - self.finish_with_state(false, cats).into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { self.try_with_state(false, |cats| cats.take(indices)) .map(|ca| ca.into_series()) @@ -300,6 +280,17 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } + + fn min_as_series(&self) -> PolarsResult { + Ok(ChunkAggSeries::min_as_series(&self.0)) + } + + fn max_as_series(&self) -> PolarsResult { + Ok(ChunkAggSeries::max_as_series(&self.0)) + } + fn as_any(&self) -> &dyn Any { + &self.0 + } } impl private::PrivateSeriesNumeric for SeriesWrap { diff --git a/crates/polars-core/src/series/implementations/dates_time.rs b/crates/polars-core/src/series/implementations/dates_time.rs index 058ef8f8c523..5f5e993dcbbc 100644 --- a/crates/polars-core/src/series/implementations/dates_time.rs +++ b/crates/polars-core/src/series/implementations/dates_time.rs @@ -7,15 +7,7 @@ //! opting for a little more run time cost. We cast to the physical type -> apply the operation and //! (depending on the result) cast back to the original type //! -use std::borrow::Cow; -use std::ops::Deref; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::ops::ToBitRepr; -use crate::chunked_array::AsSinglePtr; +use super::*; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; @@ -38,10 +30,10 @@ macro_rules! impl_dyn_series { fn _dtype(&self) -> &DataType { self.0.dtype() } - fn _get_flags(&self) -> Settings{ + fn _get_flags(&self) -> Settings { self.0.get_flags() } - fn _set_flags(&mut self, flags: Settings){ + fn _set_flags(&mut self, flags: Settings) { self.0.set_flags(flags) } @@ -78,17 +70,17 @@ macro_rules! impl_dyn_series { Ok(()) } - #[cfg(feature = "algorithm_group_by")] + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups).$into_logical().into_series() } - #[cfg(feature = "algorithm_group_by")] + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups).$into_logical().into_series() } - #[cfg(feature = "algorithm_group_by")] + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 @@ -104,7 +96,7 @@ macro_rules! impl_dyn_series { let lhs = self.cast(&dt)?; let rhs = rhs.cast(&dt)?; lhs.subtract(&rhs) - } + }, (DataType::Date, DataType::Duration(_)) => ((&self .cast(&DataType::Datetime(TimeUnit::Milliseconds, None)) .unwrap()) @@ -132,7 +124,7 @@ macro_rules! impl_dyn_series { fn remainder(&self, rhs: &Series) -> PolarsResult { polars_bail!(opq = rem, self.0.dtype(), rhs.dtype()); } - #[cfg(feature = "algorithm_group_by")] + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } @@ -143,7 +135,6 @@ macro_rules! impl_dyn_series { } impl SeriesTrait for SeriesWrap<$ca> { - fn rename(&mut self, name: &str) { self.0.rename(name); } @@ -205,18 +196,6 @@ macro_rules! impl_dyn_series { .map(|ca| ca.$into_logical().into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - let ca = self.0.deref().take_chunked_unchecked(by, sorted); - ca.$into_logical().into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - let ca = self.0.deref().take_opt_chunked_unchecked(by); - ca.$into_logical().into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self.0.take(indices)?.$into_logical().into_series()) } @@ -250,7 +229,7 @@ macro_rules! impl_dyn_series { fn cast(&self, data_type: &DataType) -> PolarsResult { match (self.dtype(), data_type) { - #[cfg(feature="dtype-date")] + #[cfg(feature = "dtype-date")] (DataType::Date, DataType::String) => Ok(self .0 .clone() @@ -259,7 +238,7 @@ macro_rules! impl_dyn_series { .unwrap() .to_string("%Y-%m-%d") .into_series()), - #[cfg(feature="dtype-time")] + #[cfg(feature = "dtype-time")] (DataType::Time, DataType::String) => Ok(self .0 .clone() @@ -269,18 +248,11 @@ macro_rules! impl_dyn_series { .to_string("%T") .into_series()), #[cfg(feature = "dtype-datetime")] - (DataType::Time, DataType::Datetime(_, _)) => { - polars_bail!( - ComputeError: - "cannot cast `Time` to `Datetime`; consider using 'dt.combine'" - ); - } - #[cfg(feature = "dtype-datetime")] (DataType::Date, DataType::Datetime(_, _)) => { let mut out = self.0.cast(data_type)?; out.set_sorted_flag(self.0.is_sorted_flag()); Ok(out) - } + }, _ => self.0.cast(data_type), } } @@ -310,17 +282,17 @@ macro_rules! impl_dyn_series { self.0.has_validity() } -#[cfg(feature = "algorithm_group_by")] + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { self.0.unique().map(|ca| ca.$into_logical().into_series()) } -#[cfg(feature = "algorithm_group_by")] + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { self.0.n_unique() } -#[cfg(feature = "algorithm_group_by")] + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { self.0.arg_unique() } @@ -355,6 +327,9 @@ macro_rules! impl_dyn_series { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } + fn as_any(&self) -> &dyn Any { + &self.0 + } } }; } diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index 0985ad36bb4c..c4c8bfe1b47b 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -1,11 +1,4 @@ -use std::borrow::Cow; -use std::ops::Deref; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; +use super::*; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; @@ -205,20 +198,6 @@ impl SeriesTrait for SeriesWrap { }) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - let ca = self.0.deref().take_chunked_unchecked(by, sorted); - ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - let ca = self.0.deref().take_opt_chunked_unchecked(by); - ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { let ca = self.0.take(indices)?; Ok(ca @@ -381,4 +360,7 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } + fn as_any(&self) -> &dyn Any { + &self.0 + } } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 71c2380a03b6..6a270965efd7 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -1,4 +1,4 @@ -use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; +use super::*; use crate::prelude::*; unsafe impl IntoSeries for DecimalChunked { @@ -151,18 +151,6 @@ impl SeriesTrait for SeriesWrap { .into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - let ca = self.0.deref().take_chunked_unchecked(by, sorted); - ca.into_decimal_unchecked(self.0.precision(), self.0.scale()) - .into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.apply_physical(|ca| ca.take_opt_chunked_unchecked(by)) - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self .0 @@ -269,4 +257,7 @@ impl SeriesTrait for SeriesWrap { Int128Chunked::from_slice_options(self.name(), &[max]) })) } + fn as_any(&self) -> &dyn Any { + &self.0 + } } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index bbb1b662c317..30e3f30857e0 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -1,12 +1,7 @@ -use std::borrow::Cow; -use std::ops::{Deref, DerefMut}; +use std::ops::DerefMut; -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; @@ -242,6 +237,14 @@ impl SeriesTrait for SeriesWrap { self.0.median() } + fn std(&self, ddof: u8) -> Option { + self.0.std(ddof) + } + + fn var(&self, ddof: u8) -> Option { + self.0.var(ddof) + } + fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); let other = other.to_physical_repr().into_owned(); @@ -262,18 +265,6 @@ impl SeriesTrait for SeriesWrap { .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - let ca = self.0.deref().take_chunked_unchecked(by, sorted); - ca.into_duration(self.0.time_unit()).into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - let ca = self.0.deref().take_opt_chunked_unchecked(by); - ca.into_duration(self.0.time_unit()).into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self .0 @@ -419,19 +410,14 @@ impl SeriesTrait for SeriesWrap { fn var_as_series(&self, ddof: u8) -> PolarsResult { Ok(self .0 + .cast_time_unit(TimeUnit::Milliseconds) .var_as_series(ddof) .cast(&self.dtype().to_physical()) .unwrap() - .into_duration(self.0.time_unit())) + .into_duration(TimeUnit::Milliseconds)) } fn median_as_series(&self) -> PolarsResult { - Ok(self - .0 - .median_as_series() - .cast(&self.dtype().to_physical()) - .unwrap() - .cast(self.dtype()) - .unwrap()) + Series::new(self.name(), &[self.median().map(|v| v as i64)]).cast(self.dtype()) } fn quantile_as_series( &self, @@ -448,4 +434,7 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } + fn as_any(&self) -> &dyn Any { + &self.0 + } } diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 093d639bc68a..3332649da16b 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -1,21 +1,8 @@ -use std::borrow::Cow; - -use ahash::RandomState; -use arrow::legacy::prelude::QuantileInterpolOptions; - -use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::aggregate::{ChunkAggSeries, QuantileAggSeries, VarAggSeries}; -use crate::chunked_array::ops::compare_inner::{ - IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, -}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -#[cfg(feature = "checked_arithmetic")] -use crate::series::arithmetic::checked::NumOpsDispatchChecked; macro_rules! impl_dyn_series { ($ca: ident) => { @@ -193,14 +180,12 @@ macro_rules! impl_dyn_series { self.0.median().map(|v| v as f64) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - self.0.take_chunked_unchecked(by, sorted).into_series() + fn std(&self, ddof: u8) -> Option { + self.0.std(ddof) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.0.take_opt_chunked_unchecked(by).into_series() + fn var(&self, ddof: u8) -> Option { + self.0.var(ddof) } fn take(&self, indices: &IdxCa) -> PolarsResult { @@ -329,6 +314,9 @@ macro_rules! impl_dyn_series { fn checked_div(&self, rhs: &Series) -> PolarsResult { self.0.checked_div(rhs) } + fn as_any(&self) -> &dyn Any { + &self.0 + } } }; } diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index fa58b02fbb00..da0c2ce27366 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -1,20 +1,8 @@ -use std::any::Any; -use std::borrow::Cow; - -#[cfg(feature = "group_by_list")] -use ahash::RandomState; - use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{IntoTotalEqInner, TotalEqInner}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::{AsSinglePtr, Settings}; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; -#[cfg(feature = "chunked_ids")] -use crate::series::IsSorted; impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { @@ -117,16 +105,6 @@ impl SeriesTrait for SeriesWrap { ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - self.0.take_chunked_unchecked(by, sorted).into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.0.take_opt_chunked_unchecked(by).into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self.0.take(indices)?.into_series()) } @@ -188,7 +166,7 @@ impl SeriesTrait for SeriesWrap { } let main_thread = POOL.current_thread_index().is_none(); let groups = self.group_tuples(main_thread, false); - // safety: + // SAFETY: // groups are in bounds Ok(unsafe { self.0.clone().into_series().agg_first(&groups?) }) } diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 61d780ca6ac4..dce50670131c 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -28,24 +28,17 @@ mod struct_; use std::any::Any; use std::borrow::Cow; -use std::ops::{BitAnd, BitOr, BitXor, Deref}; +use std::ops::{BitAnd, BitOr, BitXor}; use ahash::RandomState; -use arrow::legacy::prelude::QuantileInterpolOptions; -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::aggregate::{ChunkAggSeries, QuantileAggSeries, VarAggSeries}; use crate::chunked_array::ops::compare_inner::{ IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, }; use crate::chunked_array::ops::explode::ExplodeByOffsets; -#[cfg(feature = "chunked_ids")] -use crate::chunked_array::ops::take::TakeChunked; use crate::chunked_array::AsSinglePtr; -use crate::prelude::*; -#[cfg(feature = "checked_arithmetic")] -use crate::series::arithmetic::checked::NumOpsDispatchChecked; // Utility wrapper struct pub(crate) struct SeriesWrap(pub T); @@ -290,14 +283,12 @@ macro_rules! impl_dyn_series { self.0.median() } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - self.0.take_chunked_unchecked(by, sorted).into_series() + fn std(&self, ddof: u8) -> Option { + self.0.std(ddof) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.0.take_opt_chunked_unchecked(by).into_series() + fn var(&self, ddof: u8) -> Option { + self.0.var(ddof) } fn take(&self, indices: &IdxCa) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index b0e4a7d6e26a..6c89f1b8ac3c 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -1,12 +1,9 @@ -use std::borrow::Cow; -use std::sync::Arc; +use std::any::Any; -use arrow::array::ArrayRef; use polars_error::constants::LENGTH_LIMIT_MSG; use polars_utils::IdxSize; -use crate::datatypes::IdxCa; -use crate::error::PolarsResult; +use crate::prelude::compare_inner::{IntoTotalEqInner, TotalEqInner}; use crate::prelude::explode::ExplodeByOffsets; use crate::prelude::*; use crate::series::private::{PrivateSeries, PrivateSeriesNumeric}; @@ -81,6 +78,23 @@ impl PrivateSeries for NullChunked { ExplodeByOffsets::explode_by_offsets(self, offsets) } + fn subtract(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "subtract") + } + + fn add_to(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "add_to") + } + fn multiply(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "multiply") + } + fn divide(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "divide") + } + fn remainder(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "remainder") + } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { Ok(if self.is_empty() { @@ -96,6 +110,30 @@ impl PrivateSeries for NullChunked { fn _get_flags(&self) -> Settings { Settings::empty() } + + fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + VecHash::vec_hash(self, random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + VecHash::vec_hash_combine(self, build_hasher, hashes)?; + Ok(()) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + IntoTotalEqInner::into_total_eq_inner(self) + } +} + +fn null_arithmetic(lhs: &NullChunked, rhs: &Series, op: &str) -> PolarsResult { + let output_len = match (lhs.len(), rhs.len()) { + (1, len_r) => len_r, + (len_l, 1) => len_l, + (len_l, len_r) if len_l == len_r => len_l, + _ => polars_bail!(ComputeError: "Cannot {:?} two series of different lengths.", op), + }; + Ok(NullChunked::new(lhs.name().into(), output_len).into_series()) } impl SeriesTrait for NullChunked { @@ -118,16 +156,6 @@ impl SeriesTrait for NullChunked { self.chunks.iter().map(|chunk| chunk.len()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], _sorted: IsSorted) -> Series { - NullChunked::new(self.name.clone(), by.len()).into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - NullChunked::new(self.name.clone(), by.len()).into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(NullChunked::new(self.name.clone(), indices.len()).into_series()) } @@ -185,6 +213,10 @@ impl SeriesTrait for NullChunked { Ok(AnyValue::Null) } + unsafe fn get_unchecked(&self, _index: usize) -> AnyValue { + AnyValue::Null + } + fn slice(&self, offset: i64, length: usize) -> Series { let (chunks, len) = chunkops::slice(&self.chunks, offset, length, self.len()); NullChunked { @@ -232,6 +264,9 @@ impl SeriesTrait for NullChunked { fn clone_inner(&self) -> Arc { Arc::new(self.clone()) } + fn as_any(&self) -> &dyn Any { + self + } } unsafe impl IntoSeries for NullChunked { diff --git a/crates/polars-core/src/series/implementations/object.rs b/crates/polars-core/src/series/implementations/object.rs index 5e75f3ac59c2..d60b474c376d 100644 --- a/crates/polars-core/src/series/implementations/object.rs +++ b/crates/polars-core/src/series/implementations/object.rs @@ -5,11 +5,7 @@ use ahash::RandomState; use crate::chunked_array::object::PolarsObjectSafe; use crate::chunked_array::ops::compare_inner::{IntoTotalEqInner, TotalEqInner}; -#[cfg(feature = "chunked_ids")] -use crate::chunked_array::ops::take::TakeChunked; use crate::chunked_array::Settings; -#[cfg(feature = "algorithm_group_by")] -use crate::frame::group_by::{GroupsProxy, IntoGroupsProxy}; use crate::prelude::*; use crate::series::implementations::SeriesWrap; use crate::series::private::{PrivateSeries, PrivateSeriesNumeric}; @@ -125,16 +121,6 @@ where ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - self.0.take_chunked_unchecked(by, sorted).into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.0.take_opt_chunked_unchecked(by).into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self.0.take(indices)?.into_series()) } @@ -177,6 +163,9 @@ where fn get(&self, index: usize) -> PolarsResult { ObjectChunked::get_any_value(&self.0, index) } + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + ObjectChunked::get_any_value_unchecked(&self.0, index) + } fn null_count(&self) -> usize { ObjectChunked::null_count(&self.0) } @@ -221,6 +210,14 @@ where ObjectChunked::::get_object(&self.0, index) } + unsafe fn get_object_chunked_unchecked( + &self, + chunk: usize, + index: usize, + ) -> Option<&dyn PolarsObjectSafe> { + ObjectChunked::::get_object_chunked_unchecked(&self.0, chunk, index) + } + fn as_any(&self) -> &dyn Any { &self.0 } diff --git a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs index 9294267458a2..9a8c1b1f6aa4 100644 --- a/crates/polars-core/src/series/implementations/string.rs +++ b/crates/polars-core/src/series/implementations/string.rs @@ -1,18 +1,8 @@ -use std::borrow::Cow; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{ - IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, -}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { @@ -149,16 +139,6 @@ impl SeriesTrait for SeriesWrap { ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - self.0.take_chunked_unchecked(by, sorted).into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.0.take_opt_chunked_unchecked(by).into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self.0.take(indices)?.into_series()) } @@ -268,4 +248,7 @@ impl SeriesTrait for SeriesWrap { fn str_concat(&self, delimiter: &str) -> StringChunked { self.0.str_concat(delimiter) } + fn as_any(&self) -> &dyn Any { + &self.0 + } } diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index 1e5298a15074..fcf85754aac7 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -1,5 +1,3 @@ -use std::any::Any; - use super::*; use crate::hashing::series_to_hashes; use crate::prelude::*; @@ -32,7 +30,7 @@ impl private::PrivateSeries for SeriesWrap { } fn explode_by_offsets(&self, offsets: &[i64]) -> Series { self.0 - .apply_fields(|s| s.explode_by_offsets(offsets)) + ._apply_fields(|s| s.explode_by_offsets(offsets)) .into_series() } @@ -65,7 +63,7 @@ impl private::PrivateSeries for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - let df = DataFrame::new_no_checks(vec![]); + let df = DataFrame::empty(); let gb = df .group_by_with_series(self.0.fields().to_vec(), multithreaded, sorted) .unwrap(); @@ -123,7 +121,7 @@ impl SeriesTrait for SeriesWrap { /// When offset is negative the offset is counted from the /// end of the array fn slice(&self, offset: i64, length: usize) -> Series { - let mut out = self.0.apply_fields(|s| s.slice(offset, length)); + let mut out = self.0._apply_fields(|s| s.slice(offset, length)); out.update_chunks(0); out.into_series() } @@ -178,20 +176,6 @@ impl SeriesTrait for SeriesWrap { .map(|ca| ca.into_series()) } - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - self.0 - .apply_fields(|s| s._take_chunked_unchecked(by, sorted)) - .into_series() - } - - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - self.0 - .apply_fields(|s| s._take_opt_chunked_unchecked(by)) - .into_series() - } - fn take(&self, indices: &IdxCa) -> PolarsResult { self.0 .try_apply_fields(|s| s.take(indices)) @@ -200,7 +184,7 @@ impl SeriesTrait for SeriesWrap { unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { self.0 - .apply_fields(|s| s.take_unchecked(indices)) + ._apply_fields(|s| s.take_unchecked(indices)) .into_series() } @@ -212,7 +196,7 @@ impl SeriesTrait for SeriesWrap { unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { self.0 - .apply_fields(|s| s.take_slice_unchecked(indices)) + ._apply_fields(|s| s.take_slice_unchecked(indices)) .into_series() } @@ -230,7 +214,7 @@ impl SeriesTrait for SeriesWrap { fn new_from_index(&self, index: usize, length: usize) -> Series { self.0 - .apply_fields(|s| s.new_from_index(index, length)) + ._apply_fields(|s| s.new_from_index(index, length)) .into_series() } @@ -260,7 +244,7 @@ impl SeriesTrait for SeriesWrap { } let main_thread = POOL.current_thread_index().is_none(); let groups = self.group_tuples(main_thread, false); - // safety: + // SAFETY: // groups are in bounds Ok(unsafe { self.0.clone().into_series().agg_first(&groups?) }) } @@ -314,11 +298,11 @@ impl SeriesTrait for SeriesWrap { } fn reverse(&self) -> Series { - self.0.apply_fields(|s| s.reverse()).into_series() + self.0._apply_fields(|s| s.reverse()).into_series() } fn shift(&self, periods: i64) -> Series { - self.0.apply_fields(|s| s.shift(periods)).into_series() + self.0._apply_fields(|s| s.shift(periods)).into_series() } fn clone_inner(&self) -> Arc { diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index ee56239c7f7e..3d7236b625d6 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -16,7 +16,6 @@ pub mod unstable; use std::borrow::Cow; use std::hash::{Hash, Hasher}; use std::ops::Deref; -use std::sync::Arc; use ahash::RandomState; use arrow::compute::aggregate::estimated_bytes_size; @@ -221,7 +220,8 @@ impl Series { } pub fn into_frame(self) -> DataFrame { - DataFrame::new_no_checks(vec![self]) + // SAFETY: A single-column dataframe cannot have length mismatches or duplicate names + unsafe { DataFrame::new_no_checks(vec![self]) } } /// Rename series. @@ -531,37 +531,6 @@ impl Series { .unwrap() } - /// # Safety - /// This doesn't check any bounds. Null validity is checked. - #[cfg(feature = "chunked_ids")] - pub(crate) unsafe fn _take_chunked_unchecked_threaded( - &self, - chunk_ids: &[ChunkId], - sorted: IsSorted, - rechunk: bool, - ) -> Series { - self.threaded_op(rechunk, chunk_ids.len(), &|offset, len| { - let chunk_ids = &chunk_ids[offset..offset + len]; - Ok(self._take_chunked_unchecked(chunk_ids, sorted)) - }) - .unwrap() - } - - /// # Safety - /// This doesn't check any bounds. Null validity is checked. - #[cfg(feature = "chunked_ids")] - pub(crate) unsafe fn _take_opt_chunked_unchecked_threaded( - &self, - chunk_ids: &[Option], - rechunk: bool, - ) -> Series { - self.threaded_op(rechunk, chunk_ids.len(), &|offset, len| { - let chunk_ids = &chunk_ids[offset..offset + len]; - Ok(self._take_opt_chunked_unchecked(chunk_ids)) - }) - .unwrap() - } - /// Take by index. This operation is clone. /// /// # Notes @@ -939,8 +908,6 @@ where #[cfg(test)] mod test { - use std::convert::TryFrom; - use crate::prelude::*; use crate::series::*; diff --git a/crates/polars-core/src/series/ops/downcast.rs b/crates/polars-core/src/series/ops/downcast.rs index d6225a235627..6441dfe03df4 100644 --- a/crates/polars-core/src/series/ops/downcast.rs +++ b/crates/polars-core/src/series/ops/downcast.rs @@ -4,8 +4,16 @@ use crate::series::implementations::null::NullChunked; macro_rules! unpack_chunked { ($series:expr, $expected:pat => $ca:ty, $name:expr) => { match $series.dtype() { - $expected => unsafe { - Ok(&*($series.as_ref() as *const dyn SeriesTrait as *const $ca)) + $expected => { + // Check downcast in debug compiles + #[cfg(debug_assertions)] + { + Ok($series.as_ref().as_any().downcast_ref::<$ca>().unwrap()) + } + #[cfg(not(debug_assertions))] + unsafe { + Ok(&*($series.as_ref() as *const dyn SeriesTrait as *const $ca)) + } }, dt => polars_bail!( SchemaMismatch: "invalid series dtype: expected `{}`, got `{}`", $name, dt, diff --git a/crates/polars-core/src/series/ops/extend.rs b/crates/polars-core/src/series/ops/extend.rs index a33b26c7957e..08a196335f4c 100644 --- a/crates/polars-core/src/series/ops/extend.rs +++ b/crates/polars-core/src/series/ops/extend.rs @@ -3,7 +3,8 @@ use crate::prelude::*; impl Series { /// Extend with a constant value. pub fn extend_constant(&self, value: AnyValue, n: usize) -> PolarsResult { - let s = Series::from_any_values("", &[value], false).unwrap(); + // TODO: Use `from_any_values_and_dtype` here instead of casting afterwards + let s = Series::from_any_values("", &[value], true).unwrap(); let s = s.cast(self.dtype())?; let to_append = s.new_from_index(0, n); diff --git a/crates/polars-core/src/series/ops/null.rs b/crates/polars-core/src/series/ops/null.rs index 4da16cb83ad4..ad2b8e2a221f 100644 --- a/crates/polars-core/src/series/ops/null.rs +++ b/crates/polars-core/src/series/ops/null.rs @@ -39,10 +39,7 @@ impl Series { .into_series(), #[cfg(feature = "dtype-decimal")] DataType::Decimal(precision, scale) => Int128Chunked::full_null(name, size) - .into_decimal_unchecked( - *precision, - scale.unwrap_or_else(|| unreachable!("scale should be set")), - ) + .into_decimal_unchecked(*precision, scale.unwrap_or(0)) .into_series(), #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => { diff --git a/crates/polars-core/src/series/ops/to_list.rs b/crates/polars-core/src/series/ops/to_list.rs index e35cc7a3c93c..3b1c7f757be7 100644 --- a/crates/polars-core/src/series/ops/to_list.rs +++ b/crates/polars-core/src/series/ops/to_list.rs @@ -126,7 +126,6 @@ impl Series { #[cfg(test)] mod test { use super::*; - use crate::chunked_array::builder::get_list_builder; #[test] fn test_to_list() -> PolarsResult<()> { diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 0caaa8ed97b3..eb976d0d9b78 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -1,9 +1,6 @@ use std::any::Any; use std::borrow::Cow; -#[cfg(feature = "temporal")] -use std::sync::Arc; -use arrow::legacy::prelude::QuantileInterpolOptions; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -20,7 +17,7 @@ pub enum IsSorted { } impl IsSorted { - pub(crate) fn reverse(self) -> Self { + pub fn reverse(self) -> Self { use IsSorted::*; match self { Ascending => Descending, @@ -46,8 +43,6 @@ pub(crate) mod private { use super::*; use crate::chunked_array::ops::compare_inner::{TotalEqInner, TotalOrdInner}; use crate::chunked_array::Settings; - #[cfg(feature = "algorithm_group_by")] - use crate::frame::group_by::GroupsProxy; pub trait PrivateSeriesNumeric { fn bit_repr_is_large(&self) -> bool { @@ -212,6 +207,7 @@ pub trait SeriesTrait: fn chunks(&self) -> &Vec; /// Underlying chunks. + /// /// # Safety /// The caller must ensure the length and the data types of `ArrayRef` does not change. unsafe fn chunks_mut(&mut self) -> &mut Vec; @@ -246,14 +242,6 @@ pub trait SeriesTrait: /// Filter by boolean mask. This operation clones data. fn filter(&self, _filter: &BooleanChunked) -> PolarsResult; - #[doc(hidden)] - #[cfg(feature = "chunked_ids")] - unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series; - - #[doc(hidden)] - #[cfg(feature = "chunked_ids")] - unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series; - /// Take by index. This operation is clone. fn take(&self, _indices: &IdxCa) -> PolarsResult; @@ -298,6 +286,18 @@ pub trait SeriesTrait: None } + /// Returns the std value in the array + /// Returns an option because the array is nullable. + fn std(&self, _ddof: u8) -> Option { + None + } + + /// Returns the var value in the array + /// Returns an option because the array is nullable. + fn var(&self, _ddof: u8) -> Option { + None + } + /// Returns the median value in the array /// Returns an option because the array is nullable. fn median(&self) -> Option { @@ -453,12 +453,22 @@ pub trait SeriesTrait: invalid_operation_panic!(get_object, self) } - /// Get a hold to self as `Any` trait reference. - /// Only implemented for ObjectType - fn as_any(&self) -> &dyn Any { - invalid_operation_panic!(as_any, self) + #[cfg(feature = "object")] + /// Get the value at this index as a downcastable Any trait ref. + /// + /// # Safety + /// This function doesn't do any bound checks. + unsafe fn get_object_chunked_unchecked( + &self, + _chunk: usize, + _index: usize, + ) -> Option<&dyn PolarsObjectSafe> { + invalid_operation_panic!(get_object_chunked_unchecked, self) } + /// Get a hold to self as `Any` trait reference. + fn as_any(&self) -> &dyn Any; + /// Get a hold to self as `Any` trait reference. /// Only implemented for ObjectType fn as_any_mut(&mut self) -> &mut dyn Any { diff --git a/crates/polars-core/src/series/unstable.rs b/crates/polars-core/src/series/unstable.rs index ef85d673a6dc..d4bed3cb1035 100644 --- a/crates/polars-core/src/series/unstable.rs +++ b/crates/polars-core/src/series/unstable.rs @@ -45,6 +45,7 @@ impl<'a> UnstableSeries<'a> { } /// Creates a new `[UnsafeSeries]` + /// /// # Safety /// Inner chunks must be from `Series` otherwise the dtype may be incorrect and lead to UB. #[inline] diff --git a/crates/polars-core/src/testing.rs b/crates/polars-core/src/testing.rs index f2f4829c5cbd..82003da6f0c2 100644 --- a/crates/polars-core/src/testing.rs +++ b/crates/polars-core/src/testing.rs @@ -45,7 +45,7 @@ impl Series { pub fn get_data_ptr(&self) -> usize { let object = self.0.deref(); - // Safety: + // SAFETY: // A fat pointer consists of a data ptr and a ptr to the vtable. // we specifically check that we only transmute &dyn SeriesTrait e.g. // a trait object, therefore this is sound. diff --git a/crates/polars-core/src/utils/flatten.rs b/crates/polars-core/src/utils/flatten.rs index 7b5b56bde98b..fdefb3220b19 100644 --- a/crates/polars-core/src/utils/flatten.rs +++ b/crates/polars-core/src/utils/flatten.rs @@ -4,20 +4,20 @@ use super::*; pub fn flatten_df_iter(df: &DataFrame) -> impl Iterator + '_ { df.iter_chunks_physical().flat_map(|chunk| { - let df = DataFrame::new_no_checks( - df.iter() - .zip(chunk.into_arrays()) - .map(|(s, arr)| { - // Safety: - // datatypes are correct - let mut out = unsafe { - Series::from_chunks_and_dtype_unchecked(s.name(), vec![arr], s.dtype()) - }; - out.set_sorted_flag(s.is_sorted_flag()); - out - }) - .collect(), - ); + let columns = df + .iter() + .zip(chunk.into_arrays()) + .map(|(s, arr)| { + // SAFETY: + // datatypes are correct + let mut out = unsafe { + Series::from_chunks_and_dtype_unchecked(s.name(), vec![arr], s.dtype()) + }; + out.set_sorted_flag(s.is_sorted_flag()); + out + }) + .collect(); + let df = unsafe { DataFrame::new_no_checks(columns) }; if df.height() == 0 { None } else { diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index a9f68c58693d..3d8f9a9338c7 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -133,7 +133,11 @@ pub fn split_series(s: &Series, n: usize) -> PolarsResult> { split_array!(s, n, i64) } -pub fn split_df_as_ref(df: &DataFrame, n: usize) -> PolarsResult> { +pub fn split_df_as_ref( + df: &DataFrame, + n: usize, + extend_sub_chunks: bool, +) -> PolarsResult> { let total_len = df.height(); let chunk_size = std::cmp::max(total_len / n, 1); @@ -155,7 +159,7 @@ pub fn split_df_as_ref(df: &DataFrame, n: usize) -> PolarsResult> chunk_size }; let df = df.slice((i * chunk_size) as i64, len); - if df.n_chunks() > 1 { + if extend_sub_chunks && df.n_chunks() > 1 { // we add every chunk as separate dataframe. This make sure that every partition // deals with it. out.extend(flatten_df_iter(&df)) @@ -175,7 +179,7 @@ pub fn split_df(df: &mut DataFrame, n: usize) -> PolarsResult> { } // make sure that chunks are aligned. df.align_chunks(); - split_df_as_ref(df, n) + split_df_as_ref(df, n, true) } pub fn slice_slice(vals: &[T], offset: i64, len: usize) -> &[T] { @@ -561,6 +565,21 @@ pub fn get_time_units(tu_l: &TimeUnit, tu_r: &TimeUnit) -> TimeUnit { } } +pub fn accumulate_dataframes_vertical_unchecked_optional(dfs: I) -> Option +where + I: IntoIterator, +{ + let mut iter = dfs.into_iter(); + let additional = iter.size_hint().0; + let mut acc_df = iter.next()?; + acc_df.reserve_chunks(additional); + + for df in iter { + acc_df.vstack_mut_unchecked(&df); + } + Some(acc_df) +} + /// This takes ownership of the DataFrame so that drop is called earlier. /// Does not check if schema is correct pub fn accumulate_dataframes_vertical_unchecked(dfs: I) -> DataFrame diff --git a/crates/polars-core/src/utils/series.rs b/crates/polars-core/src/utils/series.rs index 6a107d595a48..2c8f5e65ca73 100644 --- a/crates/polars-core/src/utils/series.rs +++ b/crates/polars-core/src/utils/series.rs @@ -2,7 +2,7 @@ use crate::prelude::*; use crate::series::unstable::UnstableSeries; use crate::series::IsSorted; -/// Transform to physical type and coerce floating point and similar sized integer to a bit representation +/// Transform to physical type and coerce similar sized integer to a bit representation /// to reduce compiler bloat pub fn _to_physical_and_bit_repr(s: &[Series]) -> Vec { s.iter() @@ -11,8 +11,6 @@ pub fn _to_physical_and_bit_repr(s: &[Series]) -> Vec { match physical.dtype() { DataType::Int64 => physical.bit_repr_large().into_series(), DataType::Int32 => physical.bit_repr_small().into_series(), - DataType::Float32 => physical.bit_repr_small().into_series(), - DataType::Float64 => physical.bit_repr_large().into_series(), _ => physical.into_owned(), } }) diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs index 2f0288e52c73..f6878fe419bc 100644 --- a/crates/polars-core/src/utils/supertype.rs +++ b/crates/polars-core/src/utils/supertype.rs @@ -264,13 +264,13 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option { Some(Struct(new_fields)) } #[cfg(feature = "dtype-decimal")] - (d @ Decimal(_, _), dt) if dt.is_signed_integer() || dt.is_unsigned_integer() => Some(d.clone()), - #[cfg(feature = "dtype-decimal")] (Decimal(p1, s1), Decimal(p2, s2)) => { Some(Decimal((*p1).zip(*p2).map(|(p1, p2)| p1.max(p2)), (*s1).max(*s2))) } #[cfg(feature = "dtype-decimal")] (Decimal(_, _), f @ (Float32 | Float64)) => Some(f.clone()), + #[cfg(feature = "dtype-decimal")] + (d @ Decimal(_, _), dt) if dt.is_signed_integer() || dt.is_unsigned_integer() => Some(d.clone()), _ => None, } } diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml index 95e47cf55264..f0573aee8ed1 100644 --- a/crates/polars-io/Cargo.toml +++ b/crates/polars-io/Cargo.toml @@ -74,6 +74,10 @@ avro = ["arrow/io_avro", "arrow/io_avro_compression"] csv = ["atoi_simd", "polars-core/rows", "itoa", "ryu", "fast-float", "simdutf8"] decompress = ["flate2/rust_backend", "zstd"] decompress-fast = ["flate2/zlib-ng", "zstd"] +dtype-u8 = ["polars-core/dtype-u8"] +dtype-u16 = ["polars-core/dtype-u16"] +dtype-i8 = ["polars-core/dtype-i8"] +dtype-i16 = ["polars-core/dtype-i16"] dtype-categorical = ["polars-core/dtype-categorical"] dtype-date = ["polars-core/dtype-date", "polars-time/dtype-date"] object = [] diff --git a/crates/polars-io/src/cloud/adaptors.rs b/crates/polars-io/src/cloud/adaptors.rs index 3d0316f60ab7..0b290c275962 100644 --- a/crates/polars-io/src/cloud/adaptors.rs +++ b/crates/polars-io/src/cloud/adaptors.rs @@ -4,7 +4,6 @@ //! [parquet2]: https://crates.io/crates/parquet2 use std::io::{self}; use std::pin::Pin; -use std::sync::Arc; use std::task::Poll; use bytes::Bytes; @@ -12,8 +11,8 @@ use futures::executor::block_on; use futures::future::BoxFuture; use futures::{AsyncRead, AsyncSeek, Future, TryFutureExt}; use object_store::path::Path; -use object_store::{MultipartId, ObjectStore}; -use polars_error::{to_compute_err, PolarsError, PolarsResult}; +use object_store::MultipartId; +use polars_error::to_compute_err; use tokio::io::{AsyncWrite, AsyncWriteExt}; use super::*; @@ -226,7 +225,6 @@ impl Drop for CloudWriter { #[cfg(feature = "csv")] #[cfg(test)] mod tests { - use object_store::ObjectStore; use polars_core::df; use polars_core::prelude::{DataFrame, NamedFrom}; diff --git a/crates/polars-io/src/cloud/glob.rs b/crates/polars-io/src/cloud/glob.rs index f59a236b0966..fd187854c91d 100644 --- a/crates/polars-io/src/cloud/glob.rs +++ b/crates/polars-io/src/cloud/glob.rs @@ -3,7 +3,7 @@ use futures::future::ready; use futures::{StreamExt, TryStreamExt}; use object_store::path::Path; use polars_core::error::to_compute_err; -use polars_core::prelude::{polars_ensure, polars_err, PolarsError, PolarsResult}; +use polars_core::prelude::{polars_ensure, polars_err}; use regex::Regex; use url::Url; diff --git a/crates/polars-io/src/cloud/object_store_setup.rs b/crates/polars-io/src/cloud/object_store_setup.rs index 64860be34183..94fe4563f9d3 100644 --- a/crates/polars-io/src/cloud/object_store_setup.rs +++ b/crates/polars-io/src/cloud/object_store_setup.rs @@ -1,17 +1,17 @@ use once_cell::sync::Lazy; pub use options::*; use polars_error::to_compute_err; +use polars_utils::aliases::PlHashMap; use tokio::sync::RwLock; +use url::Url; use super::*; -type CacheKey = (String, Option); - -/// A very simple cache that only stores a single object-store. -/// This greatly reduces the query times as multiple object stores (when reading many small files) +/// Object stores must be cached. Every object-store will do DNS lookups and /// get rate limited when querying the DNS (can take up to 5s). +/// Other reasons are connection pools that must be shared between as much as possible. #[allow(clippy::type_complexity)] -static OBJECT_STORE_CACHE: Lazy)>>> = +static OBJECT_STORE_CACHE: Lazy>>> = Lazy::new(Default::default); type BuildResult = PolarsResult<(CloudLocation, Arc)>; @@ -24,26 +24,31 @@ fn err_missing_feature(feature: &str, scheme: &str) -> BuildResult { ); } +/// Get the key of a url for object store registration. +/// The credential info will be removed +fn url_to_key(url: &Url) -> String { + format!( + "{}://{}", + url.scheme(), + &url[url::Position::BeforeHost..url::Position::AfterPort], + ) +} + /// Build an [`ObjectStore`] based on the URL and passed in url. Return the cloud location and an implementation of the object store. pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> BuildResult { let parsed = parse_url(url).map_err(to_compute_err)?; let cloud_location = CloudLocation::from_url(&parsed)?; - let options = options.cloned(); - let key = (url.to_string(), options); + let key = url_to_key(&parsed); { let cache = OBJECT_STORE_CACHE.read().await; - if let Some((stored_key, store)) = cache.as_ref() { - if stored_key == &key { - return Ok((cloud_location, store.clone())); - } + if let Some(store) = cache.get(&key) { + return Ok((cloud_location, store.clone())); } } - let options = key - .1 - .as_ref() + let options = options .map(Cow::Borrowed) .unwrap_or_else(|| Cow::Owned(Default::default())); @@ -98,6 +103,10 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu }, }?; let mut cache = OBJECT_STORE_CACHE.write().await; - *cache = Some((key, store.clone())); + // Clear the cache if we surpass a certain amount of buckets. Don't expect that to happen. + if cache.len() > 512 { + cache.clear() + } + cache.insert(key, store.clone()); Ok((cloud_location, store)) } diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs index 5338682d477d..c744b0b2b12d 100644 --- a/crates/polars-io/src/cloud/options.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -24,7 +24,6 @@ use object_store::ObjectStore; use object_store::{BackoffConfig, RetryConfig}; #[cfg(feature = "aws")] use once_cell::sync::Lazy; -use polars_core::error::{PolarsError, PolarsResult}; use polars_error::*; #[cfg(feature = "aws")] use polars_utils::cache::FastFixedCache; @@ -37,6 +36,8 @@ use smartstring::alias::String as SmartString; #[cfg(feature = "cloud")] use url::Url; +#[cfg(feature = "aws")] +use crate::pl_async::with_concurrency_budget; #[cfg(feature = "aws")] use crate::utils::resolve_homedir; @@ -284,13 +285,16 @@ impl CloudOptions { builder = builder.with_config(AmazonS3ConfigKey::Region, "us-east-1"); } else { polars_warn!("'(default_)region' not set; polars will try to get it from bucket\n\nSet the region manually to silence this warning."); - let result = reqwest::Client::builder() - .build() - .unwrap() - .head(format!("https://{bucket}.s3.amazonaws.com")) - .send() - .await - .map_err(to_compute_err)?; + let result = with_concurrency_budget(1, || async { + reqwest::Client::builder() + .build() + .unwrap() + .head(format!("https://{bucket}.s3.amazonaws.com")) + .send() + .await + .map_err(to_compute_err) + }) + .await?; if let Some(region) = result.headers().get("x-amz-bucket-region") { let region = std::str::from_utf8(region.as_bytes()).map_err(to_compute_err)?; diff --git a/crates/polars-io/src/csv/buffer.rs b/crates/polars-io/src/csv/buffer.rs index 130e5b65851e..59852c6e47fa 100644 --- a/crates/polars-io/src/csv/buffer.rs +++ b/crates/polars-io/src/csv/buffer.rs @@ -29,6 +29,20 @@ impl PrimitiveParser for Float64Type { } } +#[cfg(feature = "dtype-u8")] +impl PrimitiveParser for UInt8Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} +#[cfg(feature = "dtype-u16")] +impl PrimitiveParser for UInt16Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} impl PrimitiveParser for UInt32Type { #[inline] fn parse(bytes: &[u8]) -> Option { @@ -41,6 +55,20 @@ impl PrimitiveParser for UInt64Type { atoi_simd::parse_skipped(bytes).ok() } } +#[cfg(feature = "dtype-i8")] +impl PrimitiveParser for Int8Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} +#[cfg(feature = "dtype-i16")] +impl PrimitiveParser for Int16Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} impl PrimitiveParser for Int32Type { #[inline] fn parse(bytes: &[u8]) -> Option { @@ -169,7 +197,7 @@ impl ParsedBuffer for Utf8Field { self.scratch.reserve(bytes.len()); polars_ensure!(bytes.len() > 1, ComputeError: "invalid csv file\n\nField `{}` is not properly escaped.", std::str::from_utf8(bytes).map_err(to_compute_err)?); - // Safety: + // SAFETY: // we just allocated enough capacity and data_len is correct. unsafe { let n_written = @@ -251,7 +279,7 @@ impl CategoricalField { polars_ensure!(bytes.len() > 1, ComputeError: "invalid csv file\n\nField `{}` is not properly escaped.", std::str::from_utf8(bytes).map_err(to_compute_err)?); self.escape_scratch.clear(); self.escape_scratch.reserve(bytes.len()); - // Safety: + // SAFETY: // we just allocated enough capacity and data_len is correct. unsafe { let n_written = escape_field( @@ -262,12 +290,12 @@ impl CategoricalField { self.escape_scratch.set_len(n_written); } - // safety: + // SAFETY: // just did utf8 check let key = unsafe { std::str::from_utf8_unchecked(&self.escape_scratch) }; self.builder.append_value(key); } else { - // safety: + // SAFETY: // just did utf8 check unsafe { self.builder @@ -343,7 +371,7 @@ where DatetimeInfer: TryFromWithUnit, { let val = if bytes.is_ascii() { - // Safety: + // SAFETY: // we just checked it is ascii unsafe { std::str::from_utf8_unchecked(bytes) } } else { @@ -369,7 +397,7 @@ where buf.builder.append_null(); return Ok(()); } else { - polars_bail!(ComputeError: "could not find a 'date/datetime' pattern for {}", val) + polars_bail!(ComputeError: "could not find a 'date/datetime' pattern for '{}'", val) } }, }, @@ -377,8 +405,17 @@ where match DatetimeInfer::try_from_with_unit(pattern, time_unit) { Ok(mut infer) => { let parsed = infer.parse(val); + let Some(parsed) = parsed else { + if ignore_errors { + buf.builder.append_null(); + return Ok(()); + } else { + polars_bail!(ComputeError: "could not parse '{}' with pattern '{:?}'", val, pattern) + } + }; + buf.compiled = Some(infer); - buf.builder.append_option(parsed); + buf.builder.append_value(parsed); Ok(()) }, Err(err) => { @@ -407,10 +444,16 @@ where _missing_is_null: bool, time_unit: Option, ) -> PolarsResult<()> { - if needs_escaping && bytes.len() > 2 { + if needs_escaping && bytes.len() >= 2 { bytes = &bytes[1..bytes.len() - 1] } + if bytes.is_empty() { + // for types other than string `_missing_is_null` is irrelevant; we always append null + self.builder.append_null(); + return Ok(()); + } + match &mut self.compiled { None => slow_datetime_parser(self, bytes, time_unit, ignore_errors), Some(compiled) => { @@ -442,8 +485,16 @@ pub(crate) fn init_buffers( let (name, dtype) = schema.get_at_index(i).unwrap(); let builder = match dtype { &DataType::Boolean => Buffer::Boolean(BooleanChunkedBuilder::new(name, capacity)), + #[cfg(feature = "dtype-i8")] + &DataType::Int8 => Buffer::Int8(PrimitiveChunkedBuilder::new(name, capacity)), + #[cfg(feature = "dtype-i16")] + &DataType::Int16 => Buffer::Int16(PrimitiveChunkedBuilder::new(name, capacity)), &DataType::Int32 => Buffer::Int32(PrimitiveChunkedBuilder::new(name, capacity)), &DataType::Int64 => Buffer::Int64(PrimitiveChunkedBuilder::new(name, capacity)), + #[cfg(feature = "dtype-u8")] + &DataType::UInt8 => Buffer::UInt8(PrimitiveChunkedBuilder::new(name, capacity)), + #[cfg(feature = "dtype-u16")] + &DataType::UInt16 => Buffer::UInt16(PrimitiveChunkedBuilder::new(name, capacity)), &DataType::UInt32 => Buffer::UInt32(PrimitiveChunkedBuilder::new(name, capacity)), &DataType::UInt64 => Buffer::UInt64(PrimitiveChunkedBuilder::new(name, capacity)), &DataType::Float32 => Buffer::Float32(PrimitiveChunkedBuilder::new(name, capacity)), @@ -476,8 +527,16 @@ pub(crate) fn init_buffers( #[allow(clippy::large_enum_variant)] pub(crate) enum Buffer { Boolean(BooleanChunkedBuilder), + #[cfg(feature = "dtype-i8")] + Int8(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-i16")] + Int16(PrimitiveChunkedBuilder), Int32(PrimitiveChunkedBuilder), Int64(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-u8")] + UInt8(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-u16")] + UInt16(PrimitiveChunkedBuilder), UInt32(PrimitiveChunkedBuilder), UInt64(PrimitiveChunkedBuilder), Float32(PrimitiveChunkedBuilder), @@ -500,8 +559,16 @@ impl Buffer { pub(crate) fn into_series(self) -> PolarsResult { let s = match self { Buffer::Boolean(v) => v.finish().into_series(), + #[cfg(feature = "dtype-i8")] + Buffer::Int8(v) => v.finish().into_series(), + #[cfg(feature = "dtype-i16")] + Buffer::Int16(v) => v.finish().into_series(), Buffer::Int32(v) => v.finish().into_series(), Buffer::Int64(v) => v.finish().into_series(), + #[cfg(feature = "dtype-u8")] + Buffer::UInt8(v) => v.finish().into_series(), + #[cfg(feature = "dtype-u16")] + Buffer::UInt16(v) => v.finish().into_series(), Buffer::UInt32(v) => v.finish().into_series(), Buffer::UInt64(v) => v.finish().into_series(), Buffer::Float32(v) => v.finish().into_series(), @@ -547,8 +614,16 @@ impl Buffer { pub(crate) fn add_null(&mut self, valid: bool) { match self { Buffer::Boolean(v) => v.append_null(), + #[cfg(feature = "dtype-i8")] + Buffer::Int8(v) => v.append_null(), + #[cfg(feature = "dtype-i16")] + Buffer::Int16(v) => v.append_null(), Buffer::Int32(v) => v.append_null(), Buffer::Int64(v) => v.append_null(), + #[cfg(feature = "dtype-u8")] + Buffer::UInt8(v) => v.append_null(), + #[cfg(feature = "dtype-u16")] + Buffer::UInt16(v) => v.append_null(), Buffer::UInt32(v) => v.append_null(), Buffer::UInt64(v) => v.append_null(), Buffer::Float32(v) => v.append_null(), @@ -581,8 +656,16 @@ impl Buffer { pub(crate) fn dtype(&self) -> DataType { match self { Buffer::Boolean(_) => DataType::Boolean, + #[cfg(feature = "dtype-i8")] + Buffer::Int8(_) => DataType::Int8, + #[cfg(feature = "dtype-i16")] + Buffer::Int16(_) => DataType::Int16, Buffer::Int32(_) => DataType::Int32, Buffer::Int64(_) => DataType::Int64, + #[cfg(feature = "dtype-u8")] + Buffer::UInt8(_) => DataType::UInt8, + #[cfg(feature = "dtype-u16")] + Buffer::UInt16(_) => DataType::UInt16, Buffer::UInt32(_) => DataType::UInt32, Buffer::UInt64(_) => DataType::UInt64, Buffer::Float32(_) => DataType::Float32, @@ -624,6 +707,24 @@ impl Buffer { missing_is_null, None, ), + #[cfg(feature = "dtype-i8")] + Int8(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + #[cfg(feature = "dtype-i16")] + Int16(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), Int32(buf) => as ParsedBuffer>::parse_bytes( buf, bytes, @@ -640,7 +741,17 @@ impl Buffer { missing_is_null, None, ), - UInt64(buf) => as ParsedBuffer>::parse_bytes( + #[cfg(feature = "dtype-u8")] + UInt8(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + #[cfg(feature = "dtype-u16")] + UInt16(buf) => as ParsedBuffer>::parse_bytes( buf, bytes, ignore_errors, @@ -656,6 +767,14 @@ impl Buffer { missing_is_null, None, ), + UInt64(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), Float32(buf) => as ParsedBuffer>::parse_bytes( buf, bytes, diff --git a/crates/polars-io/src/csv/mod.rs b/crates/polars-io/src/csv/mod.rs index 4eaf0efbd73c..fba65a0f719f 100644 --- a/crates/polars-io/src/csv/mod.rs +++ b/crates/polars-io/src/csv/mod.rs @@ -54,6 +54,7 @@ use std::fs::File; use std::io::Write; use std::path::PathBuf; +pub use parser::count_rows; use polars_core::prelude::*; #[cfg(feature = "temporal")] use polars_time::prelude::*; diff --git a/crates/polars-io/src/csv/parser.rs b/crates/polars-io/src/csv/parser.rs index 715d74d8b4b6..cbc599427fa1 100644 --- a/crates/polars-io/src/csv/parser.rs +++ b/crates/polars-io/src/csv/parser.rs @@ -1,11 +1,73 @@ +use std::path::PathBuf; + use memchr::memchr2_iter; use num_traits::Pow; use polars_core::prelude::*; +use polars_core::POOL; +use polars_utils::index::Bounded; +use rayon::prelude::*; use super::buffer::*; use crate::csv::read::NullValuesCompiled; use crate::csv::splitfields::SplitFields; +use crate::csv::utils::get_file_chunks; use crate::csv::CommentPrefix; +use crate::utils::get_reader_bytes; + +/// Read the number of rows without parsing columns +/// useful for count(*) queries +pub fn count_rows( + path: &PathBuf, + separator: u8, + quote_char: Option, + comment_prefix: Option<&CommentPrefix>, + eol_char: u8, + has_header: bool, +) -> PolarsResult { + let mut reader = polars_utils::open_file(path)?; + let reader_bytes = get_reader_bytes(&mut reader)?; + const MIN_ROWS_PER_THREAD: usize = 1024; + let max_threads = POOL.current_num_threads(); + + // Determine if parallelism is beneficial and how many threads + let n_threads = get_line_stats( + &reader_bytes, + MIN_ROWS_PER_THREAD, + eol_char, + None, + separator, + quote_char, + ) + .map(|(mean, std)| { + let n_rows = (reader_bytes.len() as f32 / (mean - 0.01 * std)) as usize; + (n_rows / MIN_ROWS_PER_THREAD).clamp(1, max_threads) + }) + .unwrap_or(1); + + let file_chunks = get_file_chunks( + &reader_bytes, + n_threads, + None, + separator, + quote_char, + eol_char, + ); + + let iter = file_chunks.into_par_iter().map(|(start, stop)| { + let local_bytes = &reader_bytes[start..stop]; + let row_iterator = SplitLines::new(local_bytes, quote_char.unwrap_or(b'"'), eol_char); + if comment_prefix.is_some() { + Ok(row_iterator + .filter(|line| !line.is_empty() && !is_comment_line(line, comment_prefix)) + .count() + - (has_header as usize)) + } else { + Ok(row_iterator.count() - (has_header as usize)) + } + }); + + POOL.install(|| iter.sum()) +} /// Skip the utf-8 Byte Order Mark. /// credits to csv-core @@ -183,7 +245,7 @@ pub(crate) fn get_line_stats( bytes: &[u8], n_lines: usize, eol_char: u8, - expected_fields: usize, + expected_fields: Option, separator: u8, quote_char: Option, ) -> Option<(f32, f32)> { @@ -199,7 +261,7 @@ pub(crate) fn get_line_stats( bytes_trunc = &bytes[offset..]; let pos = next_line_position( bytes_trunc, - Some(expected_fields), + expected_fields, separator, quote_char, eol_char, @@ -438,7 +500,7 @@ pub(super) fn parse_lines( field }; - // safety: + // SAFETY: // process fields is in bounds add_null = unsafe { null_values.is_null(field, processed_fields) } } diff --git a/crates/polars-io/src/csv/read.rs b/crates/polars-io/src/csv/read.rs index 53892463a497..6168fa620bb9 100644 --- a/crates/polars-io/src/csv/read.rs +++ b/crates/polars-io/src/csv/read.rs @@ -73,7 +73,8 @@ impl NullValuesCompiled { } } - /// Safety + /// # Safety + /// /// The caller must ensure that `index` is in bounds pub(super) unsafe fn is_null(&self, field: &[u8], index: usize) -> bool { use NullValuesCompiled::*; @@ -435,6 +436,7 @@ impl<'a, R: MmapBytesReader + 'a> CsvReader<'a, R> { let mut _has_categorical = false; let mut _err: Option = None; + #[allow(unused_mut)] let schema = overwriting_schema .iter_fields() .filter_map(|mut fld| { @@ -445,12 +447,6 @@ impl<'a, R: MmapBytesReader + 'a> CsvReader<'a, R> { // let inference decide the column type None }, - Int8 | Int16 | UInt8 | UInt16 => { - // We have not compiled these buffers, so we cast them later. - to_cast.push(fld.clone()); - fld.coerce(DataType::Int32); - Some(fld) - }, #[cfg(feature = "dtype-categorical")] Categorical(_, _) => { _has_categorical = true; @@ -532,6 +528,7 @@ impl<'a> CsvReader<'a, Box> { self.null_values.as_ref(), self.try_parse_dates, self.raise_if_empty, + &mut self.n_threads, )?; let schema = Arc::new(inferred_schema); Ok(to_batched_owned_mmap(self, schema)) @@ -561,6 +558,7 @@ impl<'a> CsvReader<'a, Box> { self.null_values.as_ref(), self.try_parse_dates, self.raise_if_empty, + &mut self.n_threads, )?; let schema = Arc::new(inferred_schema); Ok(to_batched_owned_read(self, schema)) @@ -683,6 +681,8 @@ where #[cfg(feature = "temporal")] fn parse_dates(mut df: DataFrame, fixed_schema: &Schema) -> DataFrame { + use polars_core::POOL; + let cols = unsafe { std::mem::take(df.get_columns_mut()) } .into_par_iter() .map(|s| { @@ -702,8 +702,8 @@ fn parse_dates(mut df: DataFrame, fixed_schema: &Schema) -> DataFrame { }, _ => s, } - }) - .collect::>(); + }); + let cols = POOL.install(|| cols.collect::>()); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } diff --git a/crates/polars-io/src/csv/read_impl/mod.rs b/crates/polars-io/src/csv/read_impl/mod.rs index 88177211913d..05d88ed89036 100644 --- a/crates/polars-io/src/csv/read_impl/mod.rs +++ b/crates/polars-io/src/csv/read_impl/mod.rs @@ -3,7 +3,6 @@ mod batched_read; use std::fmt; use std::ops::Deref; -use std::sync::Arc; pub use batched_mmap::*; pub use batched_read::*; @@ -73,7 +72,7 @@ pub(crate) fn cast_columns( } }) .collect::>>()?; - *df = DataFrame::new_no_checks(cols) + *df = unsafe { DataFrame::new_no_checks(cols) } } else { // cast to the original dtypes in the schema for fld in to_cast { @@ -142,7 +141,7 @@ impl<'a> CoreReader<'a> { schema: Option, columns: Option>, encoding: CsvEncoding, - n_threads: Option, + mut n_threads: Option, schema_overwrite: Option, dtype_overwrite: Option<&'a [DataType]>, sample_size: usize, @@ -208,6 +207,7 @@ impl<'a> CoreReader<'a> { null_values.as_ref(), try_parse_dates, raise_if_empty, + &mut n_threads, )?; Arc::new(inferred_schema) } @@ -290,6 +290,12 @@ impl<'a> CoreReader<'a> { bytes = &bytes[pos..]; } } + + // skip lines that are comments + while is_comment_line(bytes, self.comment_prefix.as_ref()) { + bytes = skip_this_line(bytes, quote_char, eol_char); + } + // skip header row if self.has_header { bytes = skip_this_line(bytes, quote_char, eol_char); @@ -339,7 +345,7 @@ impl<'a> CoreReader<'a> { bytes, self.sample_size, self.eol_char, - self.schema.len(), + Some(self.schema.len()), self.separator, self.quote_char, ) { @@ -419,7 +425,7 @@ impl<'a> CoreReader<'a> { let chunks = get_file_chunks( bytes, n_file_chunks, - self.schema.len(), + Some(self.schema.len()), self.separator, self.quote_char, self.eol_char, @@ -516,8 +522,8 @@ impl<'a> CoreReader<'a> { self.quote_char, self.eol_char, self.missing_is_null, - self.truncate_ragged_lines, ignore_errors, + self.truncate_ragged_lines, self.null_values.as_ref(), projection, &mut buffers, @@ -526,12 +532,11 @@ impl<'a> CoreReader<'a> { &self.schema, )?; - let mut local_df = DataFrame::new_no_checks( - buffers - .into_iter() - .map(|buf| buf.into_series()) - .collect::>()?, - ); + let columns = buffers + .into_iter() + .map(|buf| buf.into_series()) + .collect::>()?; + let mut local_df = unsafe { DataFrame::new_no_checks(columns) }; let current_row_count = local_df.height() as IdxSize; if let Some(rc) = &self.row_index { local_df.with_row_index_mut(&rc.name, Some(rc.offset)); @@ -630,12 +635,11 @@ impl<'a> CoreReader<'a> { self.schema.as_ref(), )?; - DataFrame::new_no_checks( - buffers - .into_iter() - .map(|buf| buf.into_series()) - .collect::>()?, - ) + let columns = buffers + .into_iter() + .map(|buf| buf.into_series()) + .collect::>()?; + unsafe { DataFrame::new_no_checks(columns) } }; cast_columns(&mut df, &self.to_cast, false, self.ignore_errors)?; @@ -725,10 +729,9 @@ fn read_chunk( )?; } - Ok(DataFrame::new_no_checks( - buffers - .into_iter() - .map(|buf| buf.into_series()) - .collect::>()?, - )) + let columns = buffers + .into_iter() + .map(|buf| buf.into_series()) + .collect::>()?; + Ok(unsafe { DataFrame::new_no_checks(columns) }) } diff --git a/crates/polars-io/src/csv/splitfields.rs b/crates/polars-io/src/csv/splitfields.rs index 81e42f82be89..b510080bcb36 100644 --- a/crates/polars-io/src/csv/splitfields.rs +++ b/crates/polars-io/src/csv/splitfields.rs @@ -64,7 +64,7 @@ mod inner { // There can be strings with separators: // "Street, City", - // Safety: + // SAFETY: // we have checked bounds let pos = if self.quoting && unsafe { *self.v.get_unchecked(0) } == self.quote_char { needs_escaping = true; @@ -90,7 +90,7 @@ mod inner { if !in_field && self.eof_oel(c) { if c == self.eol_char { - // safety + // SAFETY: // we are in bounds return unsafe { self.finish_eol(needs_escaping, current_idx as usize) @@ -111,7 +111,7 @@ mod inner { match self.v.iter().position(|&c| self.eof_oel(c)) { None => return self.finish(needs_escaping), Some(idx) => unsafe { - // Safety: + // SAFETY: // idx was just found if *self.v.get_unchecked(idx) == self.eol_char { return self.finish_eol(needs_escaping, idx); @@ -124,7 +124,7 @@ mod inner { unsafe { debug_assert!(pos <= self.v.len()); - // safety + // SAFETY: // we are in bounds let ret = Some((self.v.get_unchecked(..pos), needs_escaping)); self.v = self.v.get_unchecked(pos + 1..); @@ -226,7 +226,7 @@ mod inner { // There can be strings with separators: // "Street, City", - // Safety: + // SAFETY: // we have checked bounds let pos = if self.quoting && unsafe { *self.v.get_unchecked(0) } == self.quote_char { needs_escaping = true; @@ -252,7 +252,7 @@ mod inner { if !in_field && self.eof_oel(c) { if c == self.eol_char { - // safety + // SAFETY: // we are in bounds return unsafe { self.finish_eol(needs_escaping, current_idx as usize) @@ -318,7 +318,7 @@ mod inner { unsafe { debug_assert!(pos <= self.v.len()); - // safety + // SAFETY: // we are in bounds let ret = Some((self.v.get_unchecked(..pos), needs_escaping)); self.v = self.v.get_unchecked(pos + 1..); diff --git a/crates/polars-io/src/csv/utils.rs b/crates/polars-io/src/csv/utils.rs index 7210aa65fa69..64d98a6ff3ee 100644 --- a/crates/polars-io/src/csv/utils.rs +++ b/crates/polars-io/src/csv/utils.rs @@ -3,12 +3,13 @@ use std::borrow::Cow; use std::io::Read; use std::mem::MaybeUninit; -use polars_core::datatypes::PlHashSet; +use polars_core::config::verbose; use polars_core::prelude::*; #[cfg(feature = "polars-time")] use polars_time::chunkedarray::string::infer as date_infer; #[cfg(feature = "polars-time")] use polars_time::prelude::string::Pattern; +use polars_utils::slice::GetSaferUnchecked; #[cfg(any(feature = "decompress", feature = "decompress-fast"))] use crate::csv::parser::next_line_position_naive; @@ -23,7 +24,7 @@ use crate::utils::{BOOLEAN_RE, FLOAT_RE, INTEGER_RE}; pub(crate) fn get_file_chunks( bytes: &[u8], n_chunks: usize, - expected_fields: usize, + expected_fields: Option, separator: u8, quote_char: Option, eol_char: u8, @@ -41,7 +42,7 @@ pub(crate) fn get_file_chunks( let end_pos = match next_line_position( &bytes[search_pos..], - Some(expected_fields), + expected_fields, separator, quote_char, eol_char, @@ -150,6 +151,7 @@ pub fn infer_file_schema_inner( try_parse_dates: bool, recursion_count: u8, raise_if_empty: bool, + n_threads: &mut Option, ) -> PolarsResult<(Schema, usize, usize)> { // keep track so that we can determine the amount of bytes read let start_ptr = reader_bytes.as_ptr() as usize; @@ -249,6 +251,7 @@ pub fn infer_file_schema_inner( try_parse_dates, recursion_count + 1, raise_if_empty, + n_threads, ); } else if !raise_if_empty { return Ok((Schema::new(), 0, 0)); @@ -317,7 +320,7 @@ pub fn infer_file_schema_inner( for i in 0..header_length { if let Some((slice, needs_escaping)) = record.next() { if slice.is_empty() { - nulls[i] = true; + unsafe { *nulls.get_unchecked_release_mut(i) = true }; } else { let slice_escaped = if needs_escaping && (slice.len() >= 2) { &slice[1..(slice.len() - 1)] @@ -325,32 +328,57 @@ pub fn infer_file_schema_inner( slice }; let s = parse_bytes_with_encoding(slice_escaped, encoding)?; - match &null_values { - None => { - column_types[i].insert(infer_field_schema(&s, try_parse_dates)); - }, + let dtype = match &null_values { + None => Some(infer_field_schema(&s, try_parse_dates)), Some(NullValues::AllColumns(names)) => { if !names.iter().any(|nv| nv == s.as_ref()) { - column_types[i].insert(infer_field_schema(&s, try_parse_dates)); + Some(infer_field_schema(&s, try_parse_dates)) + } else { + None } }, Some(NullValues::AllColumnsSingle(name)) => { if s.as_ref() != name { - column_types[i].insert(infer_field_schema(&s, try_parse_dates)); + Some(infer_field_schema(&s, try_parse_dates)) + } else { + None } }, Some(NullValues::Named(names)) => { - let current_name = &headers[i]; + // SAFETY: + // we iterate over headers length. + let current_name = unsafe { headers.get_unchecked_release(i) }; let null_name = &names.iter().find(|name| &name.0 == current_name); if let Some(null_name) = null_name { if null_name.1 != s.as_ref() { - column_types[i].insert(infer_field_schema(&s, try_parse_dates)); + Some(infer_field_schema(&s, try_parse_dates)) + } else { + None } } else { - column_types[i].insert(infer_field_schema(&s, try_parse_dates)); + Some(infer_field_schema(&s, try_parse_dates)) } }, + }; + if let Some(dtype) = dtype { + if matches!(&dtype, DataType::String) + && needs_escaping + && n_threads.unwrap_or(2) > 1 + { + // The parser will chunk the file. + // However this will be increasingly unlikely to be correct if there are many + // new line characters in an escaped field. So we set a (somewhat arbitrary) + // upper bound to the number of escaped lines we accept. + // On the chunking side we also have logic to make this more robust. + if slice.iter().filter(|b| **b == eol_char).count() > 8 { + if verbose() { + eprintln!("falling back to single core reading because of many escaped new line chars.") + } + *n_threads = Some(1); + } + } + unsafe { column_types.get_unchecked_release_mut(i).insert(dtype) }; } } } @@ -392,21 +420,6 @@ pub fn infer_file_schema_inner( { // we have an integer and double, fall down to double fields.push(Field::new(field_name, DataType::Float64)); - } - // prefer a datelike parse above a no parse so choose the date type - else if possibilities.contains(&DataType::String) - && possibilities.contains(&DataType::Date) - { - fields.push(Field::new(field_name, DataType::Date)); - } - // prefer a datelike parse above a no parse so choose the date type - else if possibilities.contains(&DataType::String) - && possibilities.contains(&DataType::Datetime(TimeUnit::Microseconds, None)) - { - fields.push(Field::new( - field_name, - DataType::Datetime(TimeUnit::Microseconds, None), - )); } else { // default to String for conflicting datatypes (e.g bool and int) fields.push(Field::new(field_name, DataType::String)); @@ -441,6 +454,7 @@ pub fn infer_file_schema_inner( try_parse_dates, recursion_count + 1, raise_if_empty, + n_threads, ); } @@ -473,6 +487,7 @@ pub fn infer_file_schema( null_values: Option<&NullValues>, try_parse_dates: bool, raise_if_empty: bool, + n_threads: &mut Option, ) -> PolarsResult<(Schema, usize, usize)> { infer_file_schema_inner( reader_bytes, @@ -489,6 +504,7 @@ pub fn infer_file_schema( try_parse_dates, 0, raise_if_empty, + n_threads, ) } @@ -667,7 +683,11 @@ mod test { let s = std::fs::read_to_string(path).unwrap(); let bytes = s.as_bytes(); // can be within -1 / +1 bounds. - assert!((get_file_chunks(bytes, 10, 4, b',', None, b'\n').len() as i32 - 10).abs() <= 1); - assert!((get_file_chunks(bytes, 8, 4, b',', None, b'\n').len() as i32 - 8).abs() <= 1); + assert!( + (get_file_chunks(bytes, 10, Some(4), b',', None, b'\n').len() as i32 - 10).abs() <= 1 + ); + assert!( + (get_file_chunks(bytes, 8, Some(4), b',', None, b'\n').len() as i32 - 8).abs() <= 1 + ); } } diff --git a/crates/polars-io/src/csv/write.rs b/crates/polars-io/src/csv/write.rs index 061a47c98127..70db39513742 100644 --- a/crates/polars-io/src/csv/write.rs +++ b/crates/polars-io/src/csv/write.rs @@ -165,13 +165,14 @@ where self } - pub fn batched(self, _schema: &Schema) -> PolarsResult> { + pub fn batched(self, schema: &Schema) -> PolarsResult> { let expects_bom = self.bom; let expects_header = self.header; Ok(BatchedWriter { writer: self, has_written_bom: !expects_bom, has_written_header: !expects_header, + schema: schema.clone(), }) } } @@ -180,6 +181,7 @@ pub struct BatchedWriter { writer: CsvWriter, has_written_bom: bool, has_written_header: bool, + schema: Schema, } impl BatchedWriter { @@ -208,4 +210,20 @@ impl BatchedWriter { )?; Ok(()) } + + /// Writes the header of the csv file if not done already. Returns the total size of the file. + pub fn finish(&mut self) -> PolarsResult<()> { + if !self.has_written_bom { + self.has_written_bom = true; + write_impl::write_bom(&mut self.writer.buffer)?; + } + + if !self.has_written_header { + self.has_written_header = true; + let names = self.schema.get_names(); + write_impl::write_header(&mut self.writer.buffer, &names, &self.writer.options)?; + }; + + Ok(()) + } } diff --git a/crates/polars-io/src/csv/write_impl.rs b/crates/polars-io/src/csv/write_impl.rs index 3d1d945e7872..f9c2af3c5194 100644 --- a/crates/polars-io/src/csv/write_impl.rs +++ b/crates/polars-io/src/csv/write_impl.rs @@ -65,7 +65,7 @@ fn write_integer(f: &mut Vec, val: I) { } #[allow(unused_variables)] -unsafe fn write_anyvalue( +unsafe fn write_any_value( f: &mut Vec, value: AnyValue, options: &SerializeOptions, @@ -402,7 +402,7 @@ pub(crate) fn write( df.as_single_chunk(); let cols = df.get_columns(); - // Safety: + // SAFETY: // the bck thinks the lifetime is bounded to write_buffer_pool, but at the time we return // the vectors the buffer pool, the series have already been removed from the buffers // in other words, the lifetime does not leave this scope @@ -425,7 +425,7 @@ pub(crate) fn write( for (i, col) in &mut col_iters.iter_mut().enumerate() { match col.next() { Some(value) => unsafe { - write_anyvalue( + write_any_value( &mut write_buffer, value, options, diff --git a/crates/polars-io/src/ipc/ipc_file.rs b/crates/polars-io/src/ipc/ipc_file.rs index f57d1132b7dd..671c3cf00ae9 100644 --- a/crates/polars-io/src/ipc/ipc_file.rs +++ b/crates/polars-io/src/ipc/ipc_file.rs @@ -33,7 +33,6 @@ //! assert!(df.equals(&df_read)); //! ``` use std::io::{Read, Seek}; -use std::sync::Arc; use arrow::datatypes::ArrowSchemaRef; use arrow::io::ipc::read; diff --git a/crates/polars-io/src/ipc/ipc_stream.rs b/crates/polars-io/src/ipc/ipc_stream.rs index a36ae625ed41..0e1842712a13 100644 --- a/crates/polars-io/src/ipc/ipc_stream.rs +++ b/crates/polars-io/src/ipc/ipc_stream.rs @@ -214,7 +214,7 @@ fn fix_column_order( iter.collect() }; - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } else { df } diff --git a/crates/polars-io/src/json/infer.rs b/crates/polars-io/src/json/infer.rs index 578d9bc8fadf..0019f98fb5f1 100644 --- a/crates/polars-io/src/json/infer.rs +++ b/crates/polars-io/src/json/infer.rs @@ -1,5 +1,3 @@ -use simd_json::value::BorrowedValue; - use super::*; pub(crate) fn json_values_to_supertype( diff --git a/crates/polars-io/src/json/mod.rs b/crates/polars-io/src/json/mod.rs index da7360985dc5..2e1ce2692470 100644 --- a/crates/polars-io/src/json/mod.rs +++ b/crates/polars-io/src/json/mod.rs @@ -64,12 +64,10 @@ //! pub(crate) mod infer; -use std::convert::TryFrom; use std::io::Write; use std::num::NonZeroUsize; use std::ops::Deref; -use arrow::array::{ArrayRef, StructArray}; use arrow::legacy::conversion::chunk_to_struct; use polars_core::error::to_compute_err; use polars_core::prelude::*; diff --git a/crates/polars-io/src/parquet/async_impl.rs b/crates/polars-io/src/parquet/async_impl.rs index 49bf30d45087..a8287069e2b6 100644 --- a/crates/polars-io/src/parquet/async_impl.rs +++ b/crates/polars-io/src/parquet/async_impl.rs @@ -1,14 +1,12 @@ //! Read parquet files in parallel from the Object Store without a third party crate. use std::ops::Range; -use std::sync::Arc; use arrow::datatypes::ArrowSchemaRef; use bytes::Bytes; use object_store::path::Path as ObjectPath; use object_store::ObjectStore; use polars_core::config::{get_rg_prefetch_size, verbose}; -use polars_core::datatypes::PlHashMap; -use polars_core::error::{to_compute_err, PolarsResult}; +use polars_core::error::to_compute_err; use polars_core::prelude::*; use polars_parquet::read::{self as parquet2_read, RowGroupMetaData}; use polars_parquet::write::FileMetaData; @@ -80,14 +78,17 @@ impl ParquetObjectStore { if self.length.is_some() { return Ok(()); } - self.length = Some( - self.store - .head(&self.path) - .await - .map_err(to_compute_err)? - .size as u64, - ); - Ok(()) + with_concurrency_budget(1, || async { + self.length = Some( + self.store + .head(&self.path) + .await + .map_err(to_compute_err)? + .size as u64, + ); + Ok(()) + }) + .await } pub async fn schema(&mut self) -> PolarsResult { @@ -112,9 +113,12 @@ impl ParquetObjectStore { let length = self.length; let mut reader = CloudReader::new(length, object_store, path); - parquet2_read::read_metadata_async(&mut reader) - .await - .map_err(to_compute_err) + with_concurrency_budget(1, || async { + parquet2_read::read_metadata_async(&mut reader) + .await + .map_err(to_compute_err) + }) + .await } /// Fetch and memoize the metadata of the parquet file. diff --git a/crates/polars-io/src/parquet/mmap.rs b/crates/polars-io/src/parquet/mmap.rs index c7ece315d1f3..38013da5febe 100644 --- a/crates/polars-io/src/parquet/mmap.rs +++ b/crates/polars-io/src/parquet/mmap.rs @@ -1,8 +1,6 @@ use arrow::datatypes::Field; #[cfg(feature = "async")] use bytes::Bytes; -#[cfg(feature = "async")] -use polars_core::datatypes::PlHashMap; use polars_parquet::read::{ column_iter_to_arrays, get_field_columns, ArrayIter, BasicDecompressor, ColumnChunkMetaData, PageReader, diff --git a/crates/polars-io/src/parquet/predicates.rs b/crates/polars-io/src/parquet/predicates.rs index cffe8c12d7d7..d3775864e1a3 100644 --- a/crates/polars-io/src/parquet/predicates.rs +++ b/crates/polars-io/src/parquet/predicates.rs @@ -31,7 +31,11 @@ pub(crate) fn collect_statistics( Ok(if stats.is_empty() { None } else { - Some(BatchStats::new(Arc::new(schema.into()), stats)) + Some(BatchStats::new( + Arc::new(schema.into()), + stats, + Some(md.num_rows()), + )) }) } diff --git a/crates/polars-io/src/parquet/read.rs b/crates/polars-io/src/parquet/read.rs index cac3347fd464..6038e09c78e0 100644 --- a/crates/polars-io/src/parquet/read.rs +++ b/crates/polars-io/src/parquet/read.rs @@ -1,12 +1,10 @@ use std::io::{Read, Seek}; -use std::sync::Arc; use arrow::datatypes::ArrowSchemaRef; use polars_core::prelude::*; #[cfg(feature = "cloud")] use polars_core::utils::accumulate_dataframes_vertical_unchecked; use polars_parquet::read; -use polars_parquet::write::FileMetaData; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -162,6 +160,7 @@ impl ParquetReader { chunk_size, self.use_statistics, self.hive_partition_columns, + self.parallel, ) } } @@ -233,6 +232,7 @@ pub struct ParquetAsyncReader { use_statistics: bool, hive_partition_columns: Option>, schema: Option, + parallel: ParallelStrategy, } #[cfg(feature = "cloud")] @@ -253,6 +253,7 @@ impl ParquetAsyncReader { use_statistics: true, hive_partition_columns: None, schema, + parallel: Default::default(), }) } @@ -303,6 +304,11 @@ impl ParquetAsyncReader { self } + pub fn read_parallel(mut self, parallel: ParallelStrategy) -> Self { + self.parallel = parallel; + self + } + pub async fn batched(mut self, chunk_size: usize) -> PolarsResult { let metadata = self.reader.get_metadata().await?.clone(); let schema = match self.schema { @@ -330,6 +336,7 @@ impl ParquetAsyncReader { chunk_size, self.use_statistics, self.hive_partition_columns, + self.parallel, ) } diff --git a/crates/polars-io/src/parquet/read_impl.rs b/crates/polars-io/src/parquet/read_impl.rs index e4fd8e79bf7d..8443c5fb953d 100644 --- a/crates/polars-io/src/parquet/read_impl.rs +++ b/crates/polars-io/src/parquet/read_impl.rs @@ -1,8 +1,6 @@ use std::borrow::Cow; use std::collections::VecDeque; -use std::convert::TryFrom; use std::ops::{Deref, Range}; -use std::sync::Arc; use arrow::array::new_empty_array; use arrow::datatypes::ArrowSchemaRef; @@ -111,7 +109,10 @@ pub(super) fn array_iter_to_series( /// Materializes hive partitions. /// We have a special num_rows arg, as df can be empty when a projection contains /// only hive partition columns. -/// Safety: num_rows equals the height of the df when the df height is non-zero. +/// +/// # Safety +/// +/// num_rows equals the height of the df when the df height is non-zero. pub(crate) fn materialize_hive_partitions( df: &mut DataFrame, hive_partition_columns: Option<&[Series]>, @@ -245,7 +246,7 @@ fn rg_to_dfs_optionally_par_over_columns( *remaining_rows -= projection_height; - let mut df = DataFrame::new_no_checks(columns); + let mut df = unsafe { DataFrame::new_no_checks(columns) }; if let Some(rc) = &row_index { df.with_row_index_mut(&rc.name, Some(*previous_row_count + rc.offset)); } @@ -332,7 +333,7 @@ fn rg_to_dfs_par_over_rg( }) .collect::>>()?; - let mut df = DataFrame::new_no_checks(columns); + let mut df = unsafe { DataFrame::new_no_checks(columns) }; if let Some(rc) = &row_index { df.with_row_index_mut(&rc.name, Some(row_count_start as IdxSize + rc.offset)); @@ -445,7 +446,7 @@ pub struct FetchRowGroupsFromMmapReader(ReaderBytes<'static>); impl FetchRowGroupsFromMmapReader { pub fn new(mut reader: Box) -> PolarsResult { - // safety we will keep ownership on the struct and reference the bytes on the heap. + // SAFETY: we will keep ownership on the struct and reference the bytes on the heap. // this should not work with passed bytes so we check if it is a file assert!(reader.to_file().is_some()); let reader_ptr = unsafe { @@ -547,16 +548,25 @@ impl BatchedParquetReader { chunk_size: usize, use_statistics: bool, hive_partition_columns: Option>, + mut parallel: ParallelStrategy, ) -> PolarsResult { let n_row_groups = metadata.row_groups.len(); let projection = projection.unwrap_or_else(|| (0usize..schema.len()).collect::>()); - let parallel = - if n_row_groups > projection.len() || n_row_groups > POOL.current_num_threads() { - ParallelStrategy::RowGroups - } else { - ParallelStrategy::Columns - }; + parallel = match parallel { + ParallelStrategy::Auto => { + if n_row_groups > projection.len() || n_row_groups > POOL.current_num_threads() { + ParallelStrategy::RowGroups + } else { + ParallelStrategy::Columns + } + }, + _ => parallel, + }; + + if let (ParallelStrategy::Columns, true) = (parallel, projection.len() == 1) { + parallel = ParallelStrategy::None; + } Ok(BatchedParquetReader { row_group_fetcher, diff --git a/crates/polars-io/src/parquet/write.rs b/crates/polars-io/src/parquet/write.rs index 10694d858781..776effe36f00 100644 --- a/crates/polars-io/src/parquet/write.rs +++ b/crates/polars-io/src/parquet/write.rs @@ -1,19 +1,19 @@ +use std::borrow::Cow; use std::io::Write; -use arrow::array::{Array, ArrayRef}; +use arrow::array::Array; use arrow::chunk::Chunk; -use arrow::datatypes::{ArrowDataType, PhysicalType}; +use arrow::datatypes::PhysicalType; use polars_core::prelude::*; -use polars_core::utils::{accumulate_dataframes_vertical_unchecked, split_df}; +use polars_core::utils::{accumulate_dataframes_vertical_unchecked, split_df_as_ref}; use polars_core::POOL; use polars_parquet::read::ParquetError; -use polars_parquet::write::{self, DynIter, DynStreamingIterator, Encoding, FileWriter, *}; +use polars_parquet::write::{self, *}; use rayon::prelude::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use write::{ - BrotliLevel as BrotliLevelParquet, CompressionOptions, GzipLevel as GzipLevelParquet, - ZstdLevel as ZstdLevelParquet, + BrotliLevel as BrotliLevelParquet, GzipLevel as GzipLevelParquet, ZstdLevel as ZstdLevelParquet, }; #[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] @@ -192,11 +192,26 @@ where df.align_chunks(); let n_splits = df.height() / self.row_group_size.unwrap_or(512 * 512); - if n_splits > 0 { - *df = accumulate_dataframes_vertical_unchecked(split_df(df, n_splits)?); - } + let chunked_df = if n_splits > 0 { + Cow::Owned(accumulate_dataframes_vertical_unchecked( + split_df_as_ref(df, n_splits, false)? + .into_iter() + .map(|mut df| { + // If the chunks are small enough, writing many small chunks + // leads to slow writing performance, so in that case we + // merge them. + let n_chunks = df.n_chunks(); + if n_chunks > 1 && (df.estimated_size() / n_chunks < 128 * 1024) { + df.as_single_chunk_par(); + } + df + }), + )) + } else { + Cow::Borrowed(df) + }; let mut batched = self.batched(&df.schema())?; - batched.write_batch(df)?; + batched.write_batch(&chunked_df)?; batched.finish() } } diff --git a/crates/polars-io/src/predicates.rs b/crates/polars-io/src/predicates.rs index 4da3cc660b6e..48aec098702a 100644 --- a/crates/polars-io/src/predicates.rs +++ b/crates/polars-io/src/predicates.rs @@ -168,17 +168,28 @@ impl ColumnStats { pub struct BatchStats { schema: SchemaRef, stats: Vec, + // This might not be available, + // as when prunnign hive partitions. + num_rows: Option, } impl BatchStats { - pub fn new(schema: SchemaRef, stats: Vec) -> Self { - Self { schema, stats } + pub fn new(schema: SchemaRef, stats: Vec, num_rows: Option) -> Self { + Self { + schema, + stats, + num_rows, + } } pub fn get_stats(&self, column: &str) -> polars_core::error::PolarsResult<&ColumnStats> { self.schema.try_index_of(column).map(|i| &self.stats[i]) } + pub fn num_rows(&self) -> Option { + self.num_rows + } + pub fn schema(&self) -> &SchemaRef { &self.schema } diff --git a/crates/polars-io/src/utils.rs b/crates/polars-io/src/utils.rs index 226bfd5a912a..114fc979b6eb 100644 --- a/crates/polars-io/src/utils.rs +++ b/crates/polars-io/src/utils.rs @@ -2,19 +2,10 @@ use std::io::Read; use std::path::{Path, PathBuf}; use once_cell::sync::Lazy; -#[cfg(any(feature = "csv", feature = "json"))] -use polars_core::frame::DataFrame; use polars_core::prelude::*; use regex::{Regex, RegexBuilder}; use crate::mmap::{MmapBytesReader, ReaderBytes}; -#[cfg(any( - feature = "ipc", - feature = "ipc_streaming", - feature = "parquet", - feature = "avro" -))] -use crate::ArrowSchema; pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( reader: &'a mut R, @@ -174,13 +165,13 @@ pub(crate) fn overwrite_schema( } pub static FLOAT_RE: Lazy = Lazy::new(|| { - Regex::new(r"^\s*[-+]?((\d*\.\d+)([eE][-+]?\d+)?|inf|NaN|(\d+)[eE][-+]?\d+|\d+\.)$").unwrap() + Regex::new(r"^[-+]?((\d*\.\d+)([eE][-+]?\d+)?|inf|NaN|(\d+)[eE][-+]?\d+|\d+\.)$").unwrap() }); -pub static INTEGER_RE: Lazy = Lazy::new(|| Regex::new(r"^\s*-?(\d+)$").unwrap()); +pub static INTEGER_RE: Lazy = Lazy::new(|| Regex::new(r"^-?(\d+)$").unwrap()); pub static BOOLEAN_RE: Lazy = Lazy::new(|| { - RegexBuilder::new(r"^\s*(true)$|^(false)$") + RegexBuilder::new(r"^(true|false)$") .case_insensitive(true) .build() .unwrap() diff --git a/crates/polars-json/src/ndjson/file.rs b/crates/polars-json/src/ndjson/file.rs index 0e47342274da..35700c1a6001 100644 --- a/crates/polars-json/src/ndjson/file.rs +++ b/crates/polars-json/src/ndjson/file.rs @@ -41,7 +41,7 @@ fn read_rows(reader: &mut R, rows: &mut [String], limit: usize) -> P /// /// This iterator is used to read chunks of an NDJSON in batches. /// This iterator is guaranteed to yield at least one row. -/// # Implementantion +/// # Implementation /// Advancing this iterator is IO-bounded, but does require parsing each byte to find end of lines. /// # Error /// Advancing this iterator errors iff the reader errors. diff --git a/crates/polars-json/src/ndjson/write.rs b/crates/polars-json/src/ndjson/write.rs index 10589cac3d80..90f202b02360 100644 --- a/crates/polars-json/src/ndjson/write.rs +++ b/crates/polars-json/src/ndjson/write.rs @@ -95,7 +95,7 @@ where /// /// There are two use-cases for this function: /// * to continue writing to its writer - /// * to re-use an internal buffer of its iterator + /// * to reuse an internal buffer of its iterator pub fn into_inner(self) -> (W, I) { (self.writer, self.iterator) } diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 1f859201e19a..b40a04a7a350 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -14,7 +14,7 @@ futures = { workspace = true, optional = true } polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] } polars-io = { workspace = true, features = ["lazy"] } polars-json = { workspace = true, optional = true } -polars-ops = { workspace = true } +polars-ops = { workspace = true, features = ["chunked_ids"] } polars-pipe = { workspace = true, optional = true } polars-plan = { workspace = true } polars-time = { workspace = true, optional = true } @@ -37,7 +37,7 @@ version_check = { workspace = true } [features] nightly = ["polars-core/nightly", "polars-pipe?/nightly", "polars-plan/nightly"] -streaming = ["chunked_ids", "polars-pipe", "polars-plan/streaming", "polars-ops/chunked_ids"] +streaming = ["polars-pipe", "polars-plan/streaming", "polars-ops/chunked_ids"] parquet = ["polars-io/parquet", "polars-plan/parquet", "polars-pipe?/parquet"] async = [ "polars-plan/async", @@ -47,7 +47,7 @@ async = [ cloud = ["async", "polars-pipe?/cloud", "polars-plan/cloud", "tokio", "futures"] cloud_write = ["cloud"] ipc = ["polars-io/ipc", "polars-plan/ipc", "polars-pipe?/ipc"] -json = ["polars-io/json", "polars-plan/json", "polars-json", "polars-pipe/json"] +json = ["polars-io/json", "polars-plan/json", "polars-json", "polars-pipe?/json"] csv = ["polars-io/csv", "polars-plan/csv", "polars-pipe?/csv"] temporal = [ "dtype-datetime", @@ -119,8 +119,8 @@ unique_counts = ["polars-plan/unique_counts"] log = ["polars-plan/log"] list_eval = [] cumulative_eval = [] -chunked_ids = ["polars-plan/chunked_ids", "polars-core/chunked_ids", "polars-ops/chunked_ids"] list_to_struct = ["polars-plan/list_to_struct"] +array_to_struct = ["polars-plan/array_to_struct"] python = ["pyo3", "polars-plan/python", "polars-core/python", "polars-io/python"] row_hash = ["polars-plan/row_hash"] reinterpret = ["polars-plan/reinterpret", "polars-ops/reinterpret"] @@ -209,7 +209,6 @@ features = [ "async", "bigidx", "binary_encoding", - "chunked_ids", "cloud", "cloud_write", "coalesce", @@ -244,7 +243,6 @@ features = [ "fused", "futures", "hist", - "horizontal_concat", "interpolate", "ipc", "is_first_distinct", @@ -273,7 +271,6 @@ features = [ "peaks", "pivot", "polars-json", - "polars-pipe", "polars-time", "propagate_nans", "random", diff --git a/crates/polars-lazy/src/dsl/eval.rs b/crates/polars-lazy/src/dsl/eval.rs index 95dbf6b5f97d..2eae44388117 100644 --- a/crates/polars-lazy/src/dsl/eval.rs +++ b/crates/polars-lazy/src/dsl/eval.rs @@ -82,7 +82,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized { .map(|len| { let s = s.slice(0, len); if (len - s.null_count()) >= min_periods { - let df = DataFrame::new_no_checks(vec![s]); + let df = s.into_frame(); let out = phys_expr.evaluate(&df, &state)?; finish(out) } else { @@ -91,7 +91,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized { }) .collect::>>()? } else { - let mut df_container = DataFrame::new_no_checks(vec![]); + let mut df_container = DataFrame::empty(); (1..s.len() + 1) .map(|len| { let s = s.slice(0, len); diff --git a/crates/polars-lazy/src/dsl/functions.rs b/crates/polars-lazy/src/dsl/functions.rs index 495e49422bb9..5642c02ddf12 100644 --- a/crates/polars-lazy/src/dsl/functions.rs +++ b/crates/polars-lazy/src/dsl/functions.rs @@ -32,7 +32,7 @@ pub(crate) fn concat_impl>( }; let lf = match &mut lf.logical_plan { - // re-use the same union + // reuse the same union LogicalPlan::Union { inputs: existing_inputs, options: opts, @@ -146,7 +146,12 @@ pub fn concat_lf_diagonal>( .iter() // Zip Frames with their Schemas .zip(schemas) - .map(|(lf, lf_schema)| { + .filter_map(|(lf, lf_schema)| { + if lf_schema.is_empty() { + // if the frame is empty we discard + return None; + }; + let mut lf = lf.clone(); for (name, dtype) in total_schema.iter() { // If a name from Total Schema is not present - append @@ -162,7 +167,7 @@ pub fn concat_lf_diagonal>( .map(|col_name| col(col_name)) .collect::>(), ); - Ok(reordered_lf) + Some(Ok(reordered_lf)) }) .collect::>>()?; diff --git a/crates/polars-lazy/src/dsl/list.rs b/crates/polars-lazy/src/dsl/list.rs index f2fab52272f2..9d353a25c052 100644 --- a/crates/polars-lazy/src/dsl/list.rs +++ b/crates/polars-lazy/src/dsl/list.rs @@ -61,7 +61,7 @@ fn run_per_sublist( .par_iter() .map(|opt_s| { opt_s.and_then(|s| { - let df = DataFrame::new_no_checks(vec![s]); + let df = s.into_frame(); let out = phys_expr.evaluate(&df, &state); match out { Ok(s) => Some(s), @@ -76,7 +76,7 @@ fn run_per_sublist( err = m_err.into_inner().unwrap(); ca } else { - let mut df_container = DataFrame::new_no_checks(vec![]); + let mut df_container = DataFrame::empty(); lst.into_iter() .map(|s| { @@ -120,10 +120,11 @@ fn run_on_group_by_engine( // List elements in a series. let values = Series::try_from(("", arr.values().clone())).unwrap(); let inner_dtype = lst.inner_dtype(); - // Ensure we use the logical type. - let values = values.cast(&inner_dtype).unwrap(); + // SAFETY: + // Invariant in List means values physicals can be cast to inner dtype + let values = unsafe { values.cast_unchecked(&inner_dtype).unwrap() }; - let df_context = DataFrame::new_no_checks(vec![values]); + let df_context = values.into_frame(); let phys_expr = prepare_expression_for_context("", expr, &inner_dtype, Context::Aggregation)?; let state = ExecutionState::new(); diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 278590eacf34..afd965e663bf 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -19,7 +19,6 @@ use std::path::PathBuf; use std::sync::Arc; pub use anonymous_scan::*; -use arrow::legacy::prelude::QuantileInterpolOptions; #[cfg(feature = "csv")] pub use csv::*; #[cfg(not(target_arch = "wasm32"))] @@ -31,20 +30,10 @@ pub use ipc::*; pub use ndjson::*; #[cfg(feature = "parquet")] pub use parquet::*; -use polars_core::frame::explode::MeltArgs; use polars_core::prelude::*; use polars_io::RowIndex; pub use polars_plan::frame::{AllowedOptimizations, OptState}; use polars_plan::global::FETCH_ROWS; -#[cfg(any( - feature = "ipc", - feature = "parquet", - feature = "csv", - feature = "json" -))] -use polars_plan::logical_plan::collect_fingerprints; -use polars_plan::logical_plan::optimize; -use polars_plan::utils::expr_output_name; use smartstring::alias::String as SmartString; use crate::fallible; @@ -70,6 +59,12 @@ impl IntoLazy for DataFrame { } } +impl IntoLazy for LazyFrame { + fn lazy(self) -> LazyFrame { + self + } +} + /// Lazy abstraction over an eager `DataFrame`. /// It really is an abstraction over a logical plan. The methods of this struct will incrementally /// modify a logical plan until output is requested (via [`collect`](crate::frame::LazyFrame::collect)). @@ -208,10 +203,12 @@ impl LazyFrame { self.logical_plan.describe() } - /// Return a String describing the optimized logical plan. - /// - /// Returns `Err` if optimizing the logical plan fails. - pub fn describe_optimized_plan(&self) -> PolarsResult { + /// Return a String describing the naive (un-optimized) logical plan in tree format. + pub fn describe_plan_tree(&self) -> String { + self.logical_plan.describe_tree_format() + } + + fn optimized_plan(&self) -> PolarsResult { let mut expr_arena = Arena::with_capacity(64); let mut lp_arena = Arena::with_capacity(64); let lp_top = self.clone().optimize_with_scratch( @@ -220,8 +217,21 @@ impl LazyFrame { &mut vec![], true, )?; - let logical_plan = node_to_lp(lp_top, &expr_arena, &mut lp_arena); - Ok(logical_plan.describe()) + Ok(node_to_lp(lp_top, &expr_arena, &mut lp_arena)) + } + + /// Return a String describing the optimized logical plan. + /// + /// Returns `Err` if optimizing the logical plan fails. + pub fn describe_optimized_plan(&self) -> PolarsResult { + Ok(self.optimized_plan()?.describe()) + } + + /// Return a String describing the optimized logical plan in tree format. + /// + /// Returns `Err` if optimizing the logical plan fails. + pub fn describe_optimized_plan_tree(&self) -> PolarsResult { + Ok(self.optimized_plan()?.describe_tree_format()) } /// Return a String describing the logical plan. @@ -502,7 +512,12 @@ impl LazyFrame { } }) .collect(); - self.with_columns(cast_cols) + + if cast_cols.is_empty() { + self.clone() + } else { + self.with_columns(cast_cols) + } } /// Cast all frame columns to the given dtype, resulting in a new LazyFrame @@ -1087,7 +1102,7 @@ impl LazyFrame { ) } - /// Creates the cartesian product from both frames, preserving the order of the left keys. + /// Creates the Cartesian product from both frames, preserving the order of the left keys. #[cfg(feature = "cross_join")] pub fn cross_join(self, other: LazyFrame) -> LazyFrame { self.join(other, vec![], vec![], JoinArgs::new(JoinType::Cross)) @@ -1414,7 +1429,13 @@ impl LazyFrame { /// - String columns will sum to None. pub fn median(self) -> PolarsResult { self.stats_helper( - |dt| dt.is_numeric() || matches!(dt, DataType::Boolean | DataType::Datetime(_, _)), + |dt| { + dt.is_numeric() + || matches!( + dt, + DataType::Boolean | DataType::Duration(_) | DataType::Datetime(_, _) + ) + }, |name| col(name).median(), ) } diff --git a/crates/polars-lazy/src/frame/pivot.rs b/crates/polars-lazy/src/frame/pivot.rs index c9e0339593db..e7254ea0d908 100644 --- a/crates/polars-lazy/src/frame/pivot.rs +++ b/crates/polars-lazy/src/frame/pivot.rs @@ -31,11 +31,11 @@ impl PhysicalAggExpr for PivotExpr { } } -pub fn pivot( +pub fn pivot( df: &DataFrame, - values: I0, - index: I1, - columns: I2, + index: I0, + columns: I1, + values: Option, sort_columns: bool, agg_expr: Option, // used as separator/delimiter in generated column names. @@ -43,10 +43,10 @@ pub fn pivot( ) -> PolarsResult where I0: IntoIterator, - S0: AsRef, I1: IntoIterator, - S1: AsRef, I2: IntoIterator, + S0: AsRef, + S1: AsRef, S2: AsRef, { // make sure that the root column is replaced @@ -56,20 +56,20 @@ where }); polars_ops::pivot::pivot( df, - values, index, columns, + values, sort_columns, agg_expr, separator, ) } -pub fn pivot_stable( +pub fn pivot_stable( df: &DataFrame, - values: I0, - index: I1, - columns: I2, + index: I0, + columns: I1, + values: Option, sort_columns: bool, agg_expr: Option, // used as separator/delimiter in generated column names. @@ -77,10 +77,10 @@ pub fn pivot_stable( ) -> PolarsResult where I0: IntoIterator, - S0: AsRef, I1: IntoIterator, - S1: AsRef, I2: IntoIterator, + S0: AsRef, + S1: AsRef, S2: AsRef, { // make sure that the root column is replaced @@ -90,9 +90,9 @@ where }); polars_ops::pivot::pivot_stable( df, - values, index, columns, + values, sort_columns, agg_expr, separator, diff --git a/crates/polars-lazy/src/physical_plan/executors/group_by_dynamic.rs b/crates/polars-lazy/src/physical_plan/executors/group_by_dynamic.rs index a995d248b25c..8d758ffc9f4e 100644 --- a/crates/polars-lazy/src/physical_plan/executors/group_by_dynamic.rs +++ b/crates/polars-lazy/src/physical_plan/executors/group_by_dynamic.rs @@ -1,8 +1,3 @@ -#[cfg(feature = "dynamic_group_by")] -use polars_core::frame::group_by::GroupBy; -#[cfg(feature = "dynamic_group_by")] -use polars_time::DynamicGroupOptions; - use super::*; #[cfg_attr(not(feature = "dynamic_group_by"), allow(dead_code))] diff --git a/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs b/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs index f99aa2cd618e..8fda5dd02c5a 100644 --- a/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs +++ b/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs @@ -1,5 +1,4 @@ use polars_core::utils::{accumulate_dataframes_vertical, split_df}; -use polars_core::POOL; use rayon::prelude::*; use super::*; @@ -149,7 +148,7 @@ fn estimate_unique_count(keys: &[Series], mut sample_size: usize) -> PolarsResul .iter() .map(|s| s.slice(offset, sample_size)) .collect::>(); - let df = DataFrame::new_no_checks(keys); + let df = unsafe { DataFrame::new_no_checks(keys) }; let names = df.get_column_names(); let gb = df.group_by(names).unwrap(); Ok(finish(gb.get_groups())) diff --git a/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs b/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs index cc2f1e2b677c..74e21c4d11aa 100644 --- a/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs +++ b/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs @@ -1,8 +1,3 @@ -#[cfg(feature = "dynamic_group_by")] -use polars_core::frame::group_by::GroupBy; -#[cfg(feature = "dynamic_group_by")] -use polars_time::RollingGroupOptions; - use super::*; #[cfg_attr(not(feature = "dynamic_group_by"), allow(dead_code))] diff --git a/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs b/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs index 447a22c9f6cf..14731d2a0adb 100644 --- a/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs +++ b/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs @@ -53,6 +53,54 @@ fn rolling_evaluate( }) } +fn window_evaluate( + df: &DataFrame, + state: &ExecutionState, + window: PlHashMap>, +) -> PolarsResult>> { + POOL.install(|| { + window + .par_iter() + .map(|(_, partition)| { + // clear the cache for every partitioned group + let mut state = state.split(); + // inform the expression it has window functions. + state.insert_has_window_function_flag(); + + // don't bother caching if we only have a single window function in this partition + if partition.len() == 1 { + state.remove_cache_window_flag(); + } else { + state.insert_cache_window_flag(); + } + + let mut out = Vec::with_capacity(partition.len()); + // Don't parallelize here, as this will hold a mutex and Deadlock. + for (index, e) in partition { + if e.as_expression() + .unwrap() + .into_iter() + .filter(|e| matches!(e, Expr::Window { .. })) + .count() + == 1 + { + state.insert_cache_window_flag(); + } + // caching more than one window expression is a complicated topic for another day + // see issue #2523 + else { + state.remove_cache_window_flag(); + } + + let s = e.evaluate(df, &state)?; + out.push((*index, s)); + } + Ok(out) + }) + .collect() + }) +} + fn execute_projection_cached_window_fns( df: &DataFrame, exprs: &[Arc], @@ -116,43 +164,25 @@ fn execute_projection_cached_window_fns( // The rolling expression knows how to fetch the groups. #[cfg(feature = "dynamic_group_by")] { - let partitions = rolling_evaluate(df, state, rolling)?; + let (a, b) = POOL.join( + || rolling_evaluate(df, state, rolling), + || window_evaluate(df, state, windows), + ); + + let partitions = a?; for part in partitions { selected_columns.extend_from_slice(&part) } - } - - for partition in windows { - // clear the cache for every partitioned group - let mut state = state.split(); - // inform the expression it has window functions. - state.insert_has_window_function_flag(); - - // don't bother caching if we only have a single window function in this partition - if partition.1.len() == 1 { - state.remove_cache_window_flag(); - } else { - state.insert_cache_window_flag(); + let partitions = b?; + for part in partitions { + selected_columns.extend_from_slice(&part) } - - for (index, e) in partition.1 { - if e.as_expression() - .unwrap() - .into_iter() - .filter(|e| matches!(e, Expr::Window { .. })) - .count() - == 1 - { - state.insert_cache_window_flag(); - } - // caching more than one window expression is a complicated topic for another day - // see issue #2523 - else { - state.remove_cache_window_flag(); - } - - let s = e.evaluate(df, &state)?; - selected_columns.push((index, s)); + } + #[cfg(not(feature = "dynamic_group_by"))] + { + let partitions = window_evaluate(df, state, windows)?; + for part in partitions { + selected_columns.extend_from_slice(&part) } } @@ -254,7 +284,17 @@ pub(super) fn check_expand_literals( all_equal_len = false; } let name = s.name(); - polars_ensure!(names.insert(name), duplicate = name); + + if !names.insert(name) { + let msg = format!( + "the name: '{}' is duplicate\n\n\ + It's possible that multiple expressions are returning the same default column \ + name. If this is the case, try renaming the columns with \ + `.alias(\"new_name\")` to avoid duplicate column names.", + name + ); + return Err(PolarsError::Duplicate(msg.into())); + } } } // If all series are the same length it is ok. If not we can broadcast Series of length one. @@ -276,7 +316,7 @@ pub(super) fn check_expand_literals( .collect::>()? } - let df = DataFrame::new_no_checks(selected_columns); + let df = unsafe { DataFrame::new_no_checks(selected_columns) }; // a literal could be projected to a zero length dataframe. // This prevents a panic. diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs b/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs index 856e7ba7bfec..bee325fc69d3 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs @@ -45,6 +45,7 @@ impl CsvExec { .with_rechunk(self.file_options.rechunk) .with_row_index(std::mem::take(&mut self.file_options.row_index)) .with_try_parse_dates(self.options.try_parse_dates) + .with_n_threads(self.options.n_threads) .truncate_ragged_lines(self.options.truncate_ragged_lines) .raise_if_empty(self.options.raise_if_empty) .finish() diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs index 9e8101052a4a..40b66fef3ae6 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs @@ -1,5 +1,4 @@ use super::*; -use crate::prelude::{AnonymousScan, LazyJsonLineReader}; impl AnonymousScan for LazyJsonLineReader { fn as_any(&self) -> &dyn std::any::Any { diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs index 780eaea8fec4..7c0665d3be88 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs @@ -4,7 +4,6 @@ use std::path::PathBuf; use polars_core::config::{get_file_prefetch_size, verbose}; use polars_core::utils::accumulate_dataframes_vertical; use polars_io::cloud::CloudOptions; -use polars_io::parquet::FileMetaData; use polars_io::{is_cloud_url, RowIndex}; use super::*; diff --git a/crates/polars-lazy/src/physical_plan/exotic.rs b/crates/polars-lazy/src/physical_plan/exotic.rs index bef1e42b5fc4..138f1f566eaa 100644 --- a/crates/polars-lazy/src/physical_plan/exotic.rs +++ b/crates/polars-lazy/src/physical_plan/exotic.rs @@ -31,7 +31,8 @@ pub(crate) fn prepare_expression_for_context( // create a dummy lazyframe and run a very simple optimization run so that // type coercion and simplify expression optimizations run. let column = Series::full_null(name, 0, dtype); - let lf = DataFrame::new_no_checks(vec![column]) + let lf = column + .into_frame() .lazy() .without_optimizations() .with_simplify_expr(true) diff --git a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs index 38064189eaa9..5393d2a6a4a0 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs @@ -1,12 +1,9 @@ use std::borrow::Cow; -use std::sync::Arc; use arrow::array::*; use arrow::compute::concatenate::concatenate; -use arrow::legacy::prelude::QuantileInterpolOptions; use arrow::legacy::utils::CustomIterTools; use arrow::offset::Offsets; -use polars_core::frame::group_by::{GroupByMethod, GroupsProxy}; use polars_core::prelude::*; use polars_core::utils::NoNull; #[cfg(feature = "dtype-struct")] @@ -15,7 +12,6 @@ use polars_core::POOL; use polars_ops::prelude::nan_propagating_aggregate; use crate::physical_plan::state::ExecutionState; -use crate::physical_plan::PartitionedAggregation; use crate::prelude::AggState::{AggregatedList, AggregatedScalar}; use crate::prelude::*; @@ -64,7 +60,7 @@ impl PhysicalExpr for AggregationExpr { } } - // Safety: + // SAFETY: // groups must always be in bounds. let out = unsafe { match self.agg_type { @@ -335,7 +331,7 @@ impl PartitionedAggregation for AggregationExpr { let expr = self.input.as_partitioned_aggregator().unwrap(); let series = expr.evaluate_partitioned(df, groups, state)?; - // Safety: + // SAFETY: // groups are in bounds unsafe { match self.agg_type { @@ -476,7 +472,7 @@ impl PartitionedAggregation for AggregationExpr { GroupsProxy::Idx(groups) => { for (_, idx) in groups { let ca = unsafe { - // Safety + // SAFETY: // The indexes of the group_by operation are never out of bounds ca.take_unchecked(idx) }; @@ -586,7 +582,7 @@ impl PhysicalExpr for AggQuantileExpr { let quantile = self.get_quantile(df, state)?; - // safety: + // SAFETY: // groups are in bounds let mut agg = unsafe { ac.flat_naive() diff --git a/crates/polars-lazy/src/physical_plan/expressions/alias.rs b/crates/polars-lazy/src/physical_plan/expressions/alias.rs index 44a84e96ddb5..c715083b01f4 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/alias.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/alias.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use crate::physical_plan::state::ExecutionState; diff --git a/crates/polars-lazy/src/physical_plan/expressions/apply.rs b/crates/polars-lazy/src/physical_plan/expressions/apply.rs index ea664b4dafcd..1bf465412894 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/apply.rs @@ -1,13 +1,11 @@ use std::borrow::Cow; -use std::sync::Arc; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; #[cfg(feature = "parquet")] use polars_io::predicates::{BatchStats, StatsEvaluator}; -#[cfg(feature = "parquet")] -use polars_plan::dsl::FunctionExpr; +#[cfg(feature = "is_between")] +use polars_ops::prelude::ClosedInterval; use rayon::prelude::*; use crate::physical_plan::state::ExecutionState; @@ -386,6 +384,9 @@ impl PhysicalExpr for ApplyExpr { FunctionExpr::Boolean(BooleanFunction::IsNull) => Some(self), #[cfg(feature = "is_in")] FunctionExpr::Boolean(BooleanFunction::IsIn) => Some(self), + #[cfg(feature = "is_between")] + FunctionExpr::Boolean(BooleanFunction::IsBetween { closed: _ }) => Some(self), + FunctionExpr::Boolean(BooleanFunction::IsNotNull) => Some(self), _ => None, } } @@ -498,6 +499,23 @@ impl ApplyExpr { None => Ok(true), } }, + FunctionExpr::Boolean(BooleanFunction::IsNotNull) => { + let root = expr_to_leaf_column_name(&self.expr)?; + + match stats.get_stats(&root).ok() { + Some(st) => match st.null_count() { + Some(null_count) + if stats + .num_rows() + .map_or(false, |num_rows| num_rows == null_count) => + { + Ok(false) + }, + _ => Ok(true), + }, + None => Ok(true), + } + }, #[cfg(feature = "is_in")] FunctionExpr::Boolean(BooleanFunction::IsIn) => { let should_read = || -> Option { @@ -517,9 +535,67 @@ impl ApplyExpr { return one_equals(min); } - let all_smaller = || Some(ChunkCompare::lt(input, min).ok()?.all()); - let all_bigger = || Some(ChunkCompare::gt(input, max).ok()?.all()); - Some(!all_smaller()? && !all_bigger()?) + let smaller = ChunkCompare::lt(input, min).ok()?; + let bigger = ChunkCompare::gt(input, max).ok()?; + + Some(!(smaller | bigger).all()) + }; + + Ok(should_read().unwrap_or(true)) + }, + #[cfg(feature = "is_between")] + FunctionExpr::Boolean(BooleanFunction::IsBetween { closed }) => { + let should_read = || -> Option { + let root: Arc = expr_to_leaf_column_name(&input[0]).ok()?; + let Expr::Literal(left) = &input[1] else { + return None; + }; + let Expr::Literal(right) = &input[2] else { + return None; + }; + + let st = stats.get_stats(&root).ok()?; + let min = st.to_min()?; + let max = st.to_max()?; + + let (left, left_dtype) = (left.to_any_value()?, left.get_datatype()); + let (right, right_dtype) = (right.to_any_value()?, right.get_datatype()); + + let left = + Series::from_any_values_and_dtype("", &[left], &left_dtype, false).ok()?; + let right = + Series::from_any_values_and_dtype("", &[right], &right_dtype, false) + .ok()?; + + // don't read the row_group anyways as + // the condition will evaluate to false. + // e.g. in_between(10, 5) + if ChunkCompare::gt(&left, &right).ok()?.all() { + return Some(false); + } + + let (left_open, right_open) = match closed { + ClosedInterval::None => (true, true), + ClosedInterval::Both => (false, false), + ClosedInterval::Left => (false, true), + ClosedInterval::Right => (true, false), + }; + // check the right limit of the interval. + // if the end is open, we should be stricter (lt_eq instead of lt). + if right_open && ChunkCompare::lt_eq(&right, min).ok()?.all() + || !right_open && ChunkCompare::lt(&right, min).ok()?.all() + { + return Some(false); + } + // we couldn't conclude anything using the right limit, + // check the left limit of the interval + if left_open && ChunkCompare::gt_eq(&left, max).ok()?.all() + || !left_open && ChunkCompare::gt(&left, max).ok()?.all() + { + return Some(false); + } + // read the row_group + Some(true) }; Ok(should_read().unwrap_or(true)) diff --git a/crates/polars-lazy/src/physical_plan/expressions/binary.rs b/crates/polars-lazy/src/physical_plan/expressions/binary.rs index c244c0f9bb00..f3b3d4e2f51b 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/binary.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/binary.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; #[cfg(feature = "round_series")] @@ -397,7 +394,7 @@ mod stats { } } - let dummy = DataFrame::new_no_checks(vec![]); + let dummy = DataFrame::empty(); let state = ExecutionState::new(); let out = match (self.left.is_literal(), self.right.is_literal()) { diff --git a/crates/polars-lazy/src/physical_plan/expressions/cast.rs b/crates/polars-lazy/src/physical_plan/expressions/cast.rs index 962ac5086a86..32ad204ba867 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/cast.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/cast.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use crate::physical_plan::state::ExecutionState; diff --git a/crates/polars-lazy/src/physical_plan/expressions/column.rs b/crates/polars-lazy/src/physical_plan/expressions/column.rs index d4acf8a309bc..eda761ab9d56 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/column.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/column.rs @@ -1,7 +1,5 @@ use std::borrow::Cow; -use std::sync::Arc; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_plan::constants::CSE_REPLACED; diff --git a/crates/polars-lazy/src/physical_plan/expressions/filter.rs b/crates/polars-lazy/src/physical_plan/expressions/filter.rs index e6adb24953e8..b2cfe43e3997 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/filter.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/filter.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - use arrow::legacy::is_valid::IsValid; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; use polars_utils::idx_vec::IdxVec; diff --git a/crates/polars-lazy/src/physical_plan/expressions/group_iter.rs b/crates/polars-lazy/src/physical_plan/expressions/group_iter.rs index 72b8d1420f73..3949b7c37961 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/group_iter.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/group_iter.rs @@ -17,7 +17,7 @@ impl<'a> AggregationContext<'a> { self.groups(); let s = self.series().rechunk(); let name = if keep_names { s.name() } else { "" }; - // safety: dtype is correct + // SAFETY: dtype is correct unsafe { Box::new(LitIter::new( s.array_ref(0).clone(), @@ -31,7 +31,7 @@ impl<'a> AggregationContext<'a> { self.groups(); let s = self.series(); let name = if keep_names { s.name() } else { "" }; - // safety: dtype is correct + // SAFETY: dtype is correct unsafe { Box::new(FlatIter::new( s.array_ref(0).clone(), @@ -83,7 +83,7 @@ impl<'a> LitIter<'a> { offset: 0, len, series_container, - // Safety: we pinned the series so the location is still valid + // SAFETY: we pinned the series so the location is still valid item: UnstableSeries::new(unsafe { &mut *ref_s }), } } @@ -131,7 +131,7 @@ impl<'a> FlatIter<'a> { offset: 0, len, series_container, - // Safety: we pinned the series so the location is still valid + // SAFETY: we pinned the series so the location is still valid item: UnstableSeries::new(unsafe { &mut *ref_s }), } } diff --git a/crates/polars-lazy/src/physical_plan/expressions/literal.rs b/crates/polars-lazy/src/physical_plan/expressions/literal.rs index a0618b13751c..cf33aa81c9a7 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/literal.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/literal.rs @@ -1,7 +1,6 @@ use std::borrow::Cow; use std::ops::Deref; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::utils::NoNull; use polars_plan::dsl::consts::LITERAL_NAME; diff --git a/crates/polars-lazy/src/physical_plan/expressions/mod.rs b/crates/polars-lazy/src/physical_plan/expressions/mod.rs index e13c81d01727..4642654a9fb6 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/mod.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/mod.rs @@ -31,7 +31,6 @@ pub(crate) use column::*; pub(crate) use count::*; pub(crate) use filter::*; pub(crate) use literal::*; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_io::predicates::PhysicalIoExpr; #[cfg(feature = "dynamic_group_by")] @@ -343,7 +342,7 @@ impl<'a> AggregationContext<'a> { AggState::NotAggregated(s) => { // We should not aggregate literals!! if self.state.safe_to_agg(&self.groups) { - // safety: + // SAFETY: // groups are in bounds let agg = unsafe { s.agg_list(&self.groups) }; self.update_groups = UpdateGroups::WithGroupsLen; @@ -445,7 +444,7 @@ impl<'a> AggregationContext<'a> { } } - // safety: + // SAFETY: // groups are in bounds let out = unsafe { s.agg_list(&self.groups) }; self.state = AggState::AggregatedList(out.clone()); diff --git a/crates/polars-lazy/src/physical_plan/expressions/slice.rs b/crates/polars-lazy/src/physical_plan/expressions/slice.rs index 13793e55ac34..3d0129675a96 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/slice.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/slice.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::{GroupsProxy, IdxItem}; use polars_core::prelude::*; use polars_core::utils::{slice_offsets, CustomIterTools}; use polars_core::POOL; diff --git a/crates/polars-lazy/src/physical_plan/expressions/sort.rs b/crates/polars-lazy/src/physical_plan/expressions/sort.rs index 77709c9d8a03..0df7d4b94ab9 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sort.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sort.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; use polars_ops::chunked_array::ListNameSpaceImpl; diff --git a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs index b6ef95d46a00..d213af7631ed 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::{GroupsIndicator, GroupsProxy}; use polars_core::prelude::*; use polars_core::POOL; use polars_utils::idx_vec::IdxVec; diff --git a/crates/polars-lazy/src/physical_plan/expressions/take.rs b/crates/polars-lazy/src/physical_plan/expressions/take.rs index b6b20ff5830e..9408635de332 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/take.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/take.rs @@ -1,8 +1,5 @@ -use std::sync::Arc; - use arrow::legacy::utils::CustomIterTools; use polars_core::chunked_array::builder::get_list_builder; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::utils::NoNull; use polars_ops::prelude::{convert_to_unsigned_index, is_positive_idx_uncertain}; diff --git a/crates/polars-lazy/src/physical_plan/expressions/ternary.rs b/crates/polars-lazy/src/physical_plan/expressions/ternary.rs index aed1b74cf710..d52cb4eb8d61 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/ternary.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/ternary.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; diff --git a/crates/polars-lazy/src/physical_plan/expressions/window.rs b/crates/polars-lazy/src/physical_plan/expressions/window.rs index 3ccac3a9a00d..5e60f43923f7 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/window.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/window.rs @@ -1,16 +1,12 @@ use std::fmt::Write; -use std::sync::Arc; use arrow::array::PrimitiveArray; use polars_core::export::arrow::bitmap::Bitmap; -use polars_core::frame::group_by::{GroupBy, GroupsProxy}; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::utils::_split_offsets; use polars_core::{downcast_as_macro_arg_physical, POOL}; -use polars_ops::frame::join::{ - default_join_ids, private_left_join_multiple_keys, ChunkJoinOptIds, JoinValidation, -}; +use polars_ops::frame::join::{default_join_ids, private_left_join_multiple_keys, ChunkJoinOptIds}; use polars_ops::frame::SeriesJoin; use polars_utils::format_smartstring; use polars_utils::sort::perfect_sort; @@ -18,7 +14,6 @@ use polars_utils::sync::SyncPtr; use rayon::prelude::*; use super::*; -use crate::physical_plan::state::ExecutionState; use crate::prelude::*; pub struct WindowExpr { @@ -114,12 +109,12 @@ impl WindowExpr { take_idx = original_idx; } cache_gb(gb, state, cache_key); - // Safety: + // SAFETY: // we only have unique indices ranging from 0..len unsafe { perfect_sort(&POOL, &idx_mapping, &mut take_idx) }; let idx = IdxCa::from_vec("", take_idx); - // Safety: + // SAFETY: // groups should always be in bounds. unsafe { Ok(flattened.take_unchecked(&idx)) } } @@ -567,8 +562,8 @@ impl PhysicalExpr for WindowExpr { .unwrap() .1 } else { - let df_right = DataFrame::new_no_checks(keys); - let df_left = DataFrame::new_no_checks(group_by_columns); + let df_right = unsafe { DataFrame::new_no_checks(keys) }; + let df_left = unsafe { DataFrame::new_no_checks(group_by_columns) }; private_left_join_multiple_keys( &df_left, &df_right, None, None, true, ) @@ -629,22 +624,17 @@ impl PhysicalExpr for WindowExpr { } fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Series) -> Series { - #[cfg(feature = "chunked_ids")] { use arrow::Either; + use polars_ops::chunked_array::TakeChunked; match join_opt_ids { Either::Left(ids) => unsafe { out_column.take_unchecked(&ids.iter().copied().collect_ca("")) }, - Either::Right(ids) => unsafe { out_column._take_opt_chunked_unchecked(ids) }, + Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids) }, } } - - #[cfg(not(feature = "chunked_ids"))] - unsafe { - out_column.take_unchecked(&join_opt_ids.iter().copied().collect_ca("")) - } } fn cache_gb(gb: GroupBy, state: &ExecutionState, cache_key: &str) { @@ -685,23 +675,9 @@ where T: PolarsNumericType, ChunkedArray: IntoSeries, { - let mut idx_mapping = Vec::with_capacity(len); - let mut iter = 0..len as IdxSize; - match groups { - GroupsProxy::Idx(groups) => { - for g in groups.all() { - idx_mapping.extend((&mut iter).take(g.len()).zip(g.iter().copied())); - } - }, - GroupsProxy::Slice { groups, .. } => { - for &[first, len] in groups { - idx_mapping.extend((&mut iter).take(len as usize).zip(first..first + len)); - } - }, - } let mut values = Vec::with_capacity(len); let ptr: *mut T::Native = values.as_mut_ptr(); - // safety: + // SAFETY: // we will write from different threads but we will never alias. let sync_ptr_values = unsafe { SyncPtr::new(ptr) }; @@ -743,7 +719,7 @@ where }, } - // safety: we have written all slots + // SAFETY: we have written all slots unsafe { values.set_len(len) } Some(ChunkedArray::new_vec(ca.name(), values).into_series()) } else { @@ -765,7 +741,7 @@ where let values_ptr = sync_ptr_values.get(); let validity_ptr = sync_ptr_validity.get(); - ca.into_iter().zip(groups.iter()).for_each(|(opt_v, g)| { + ca.iter().zip(groups.iter()).for_each(|(opt_v, g)| { for idx in g.as_slice() { let idx = *idx as usize; debug_assert!(idx < len); @@ -793,7 +769,7 @@ where let values_ptr = sync_ptr_values.get(); let validity_ptr = sync_ptr_validity.get(); - for (opt_v, [start, g_len]) in ca.into_iter().zip(groups.iter()) { + for (opt_v, [start, g_len]) in ca.iter().zip(groups.iter()) { let start = *start as usize; let end = start + *g_len as usize; for idx in start..end { @@ -815,7 +791,7 @@ where }) }, } - // safety: we have written all slots + // SAFETY: we have written all slots unsafe { values.set_len(len) } unsafe { validity.set_len(len) } let validity = Bitmap::from(validity); diff --git a/crates/polars-lazy/src/physical_plan/file_cache.rs b/crates/polars-lazy/src/physical_plan/file_cache.rs index 5ea1074d95b0..ee7c8e8ddffa 100644 --- a/crates/polars-lazy/src/physical_plan/file_cache.rs +++ b/crates/polars-lazy/src/physical_plan/file_cache.rs @@ -1,13 +1,6 @@ use std::sync::Mutex; use polars_core::prelude::*; -#[cfg(any( - feature = "parquet", - feature = "csv", - feature = "ipc", - feature = "json" -))] -use polars_plan::logical_plan::FileFingerPrint; use crate::prelude::*; diff --git a/crates/polars-lazy/src/physical_plan/node_timer.rs b/crates/polars-lazy/src/physical_plan/node_timer.rs index 8be6861dda39..4926f7df8c59 100644 --- a/crates/polars-lazy/src/physical_plan/node_timer.rs +++ b/crates/polars-lazy/src/physical_plan/node_timer.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; use std::time::Instant; use polars_core::prelude::*; @@ -57,10 +57,8 @@ impl NodeTimer { let mut end = end.into_inner(); end.rename("end"); - DataFrame::new_no_checks(vec![nodes_s, start.into_series(), end.into_series()]).sort( - vec!["start"], - vec![false], - false, - ) + let columns = vec![nodes_s, start.into_series(), end.into_series()]; + let df = unsafe { DataFrame::new_no_checks(columns) }; + df.sort(vec!["start"], vec![false], false) } } diff --git a/crates/polars-lazy/src/physical_plan/planner/expr.rs b/crates/polars-lazy/src/physical_plan/planner/expr.rs index 0489bf40c257..26e5c920ca6c 100644 --- a/crates/polars-lazy/src/physical_plan/planner/expr.rs +++ b/crates/polars-lazy/src/physical_plan/planner/expr.rs @@ -1,4 +1,3 @@ -use polars_core::frame::group_by::GroupByMethod; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::utils::_split_offsets; diff --git a/crates/polars-lazy/src/physical_plan/state.rs b/crates/polars-lazy/src/physical_plan/state.rs index 5112e39eb2db..c946399d5017 100644 --- a/crates/polars-lazy/src/physical_plan/state.rs +++ b/crates/polars-lazy/src/physical_plan/state.rs @@ -5,7 +5,6 @@ use std::sync::{Mutex, RwLock}; use bitflags::bitflags; use once_cell::sync::OnceCell; use polars_core::config::verbose; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_ops::prelude::ChunkJoinOptIds; #[cfg(any( diff --git a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs index 506760c9744d..80c81cb16815 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs @@ -1,7 +1,6 @@ use std::any::Any; use std::cell::RefCell; use std::rc::Rc; -use std::sync::Arc; use polars_core::config::verbose; use polars_core::prelude::*; @@ -132,7 +131,7 @@ pub(super) fn construct( let mut final_sink = None; for branch in tree { - // the file sink is always to the top of the tree + // The file sink is always to the top of the tree // not every branch has a final sink. For instance rhs join branches if let Some(node) = branch.get_final_sink() { if matches!(lp_arena.get(node), ALogicalPlan::Sink { .. }) { diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index f5bb0f50e3a4..5bdb94b2b2e6 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -1,4 +1,3 @@ -use polars_core::error::PolarsResult; use polars_core::prelude::*; use polars_pipe::pipeline::swap_join_order; use polars_plan::prelude::*; diff --git a/crates/polars-lazy/src/physical_plan/streaming/tree.rs b/crates/polars-lazy/src/physical_plan/streaming/tree.rs index d948ab366405..00ae4dfb309d 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/tree.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/tree.rs @@ -1,11 +1,7 @@ use std::collections::BTreeSet; use std::fmt::Debug; -#[cfg(debug_assertions)] use polars_plan::prelude::*; -#[cfg(debug_assertions)] -use polars_utils::arena::Arena; -use polars_utils::arena::Node; #[derive(Copy, Clone, Debug)] pub(super) enum PipelineNode { diff --git a/crates/polars-lazy/src/prelude.rs b/crates/polars-lazy/src/prelude.rs index 964ec7894a2d..5463d524ed4e 100644 --- a/crates/polars-lazy/src/prelude.rs +++ b/crates/polars-lazy/src/prelude.rs @@ -2,7 +2,8 @@ pub use polars_ops::prelude::{JoinArgs, JoinType, JoinValidation}; #[cfg(feature = "rank")] pub use polars_ops::prelude::{RankMethod, RankOptions}; pub use polars_plan::logical_plan::{ - AnonymousScan, AnonymousScanOptions, Literal, LiteralValue, LogicalPlan, Null, NULL, + AnonymousScan, AnonymousScanArgs, AnonymousScanOptions, Literal, LiteralValue, LogicalPlan, + Null, NULL, }; #[cfg(feature = "csv")] pub use polars_plan::prelude::CsvWriterOptions; diff --git a/crates/polars-lazy/src/scan/csv.rs b/crates/polars-lazy/src/scan/csv.rs index f32202ed82be..16b1317b2913 100644 --- a/crates/polars-lazy/src/scan/csv.rs +++ b/crates/polars-lazy/src/scan/csv.rs @@ -6,7 +6,6 @@ use polars_io::csv::{CommentPrefix, CsvEncoding, NullValues}; use polars_io::utils::get_reader_bytes; use polars_io::RowIndex; -use crate::frame::LazyFileListReader; use crate::prelude::*; #[derive(Clone)] @@ -36,6 +35,7 @@ pub struct LazyCsvReader<'a> { row_index: Option, try_parse_dates: bool, raise_if_empty: bool, + n_threads: Option, } #[cfg(feature = "csv")] @@ -70,6 +70,7 @@ impl<'a> LazyCsvReader<'a> { try_parse_dates: false, raise_if_empty: true, truncate_ragged_lines: false, + n_threads: None, } } @@ -233,7 +234,7 @@ impl<'a> LazyCsvReader<'a> { /// Modify a schema before we run the lazy scanning. /// /// Important! Run this function latest in the builder! - pub fn with_schema_modify(self, f: F) -> PolarsResult + pub fn with_schema_modify(mut self, f: F) -> PolarsResult where F: Fn(Schema) -> PolarsResult, { @@ -264,6 +265,7 @@ impl<'a> LazyCsvReader<'a> { None, self.try_parse_dates, self.raise_if_empty, + &mut self.n_threads, )?; let mut schema = f(schema)?; @@ -303,6 +305,7 @@ impl LazyFileListReader for LazyCsvReader<'_> { self.try_parse_dates, self.raise_if_empty, self.truncate_ragged_lines, + self.n_threads, )? .build() .into(); diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index c54e584f2731..361cb0589ecf 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -1,5 +1,5 @@ use polars_ops::prelude::ListNameSpaceImpl; -use polars_utils::idxvec; +use polars_utils::unitvec; use super::*; @@ -9,7 +9,7 @@ fn test_agg_list_type() -> PolarsResult<()> { let s = Series::new("foo", &[1, 2, 3]); let s = s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?; - let l = unsafe { s.agg_list(&GroupsProxy::Idx(vec![(0, idxvec![0, 1, 2])].into())) }; + let l = unsafe { s.agg_list(&GroupsProxy::Idx(vec![(0, unitvec![0, 1, 2])].into())) }; let result = match l.dtype() { DataType::List(inner) => { diff --git a/crates/polars-lazy/src/tests/err_msg.rs b/crates/polars-lazy/src/tests/err_msg.rs new file mode 100644 index 000000000000..5f73f1c30c9a --- /dev/null +++ b/crates/polars-lazy/src/tests/err_msg.rs @@ -0,0 +1,83 @@ +use polars_core::error::ErrString; + +use super::*; + +const INITIAL_PROJECTION_STR: &str = r#"DF ["c1"]; PROJECT */1 COLUMNS; SELECTION: "None""#; + +fn make_df() -> LazyFrame { + df! [ "c1" => [0, 1] ].unwrap().lazy() +} + +fn assert_errors_eq(e1: &PolarsError, e2: &PolarsError) { + use PolarsError::*; + match (e1, e2) { + (ColumnNotFound(s1), ColumnNotFound(s2)) => { + assert_eq!(s1.as_ref(), s2.as_ref()); + }, + (ComputeError(s1), ComputeError(s2)) => { + assert_eq!(s1.as_ref(), s2.as_ref()); + }, + _ => panic!("{e1:?} != {e2:?}"), + } +} + +#[test] +fn col_not_found_error_messages() { + fn get_err_msg(err_msg: &str, n: usize) -> String { + let plural_s; + let was_were; + + if n == 1 { + plural_s = ""; + was_were = "was" + } else { + plural_s = "s"; + was_were = "were"; + }; + format!( + "{err_msg}\n\nLogicalPlan had already failed with the above error; \ + after failure, {n} additional operation{plural_s} \ + {was_were} attempted on the LazyFrame" + ) + } + fn test_col_not_found(df: LazyFrame, n: usize) { + let err_msg = format!( + "xyz\n\nError originated just after this \ + operation:\n{INITIAL_PROJECTION_STR}" + ); + + let plan_err_str = + format!("ErrorState {{ n_times: {n}, err: ColumnNotFound(ErrString({err_msg:?})) }}"); + + let collect_err = if n == 0 { + PolarsError::ColumnNotFound(ErrString::from(err_msg.to_owned())) + } else { + PolarsError::ColumnNotFound(ErrString::from(get_err_msg(&err_msg, n))) + }; + + assert_eq!(df.describe_plan(), plan_err_str); + assert_errors_eq(&df.collect().unwrap_err(), &collect_err); + } + + let df = make_df(); + + assert_eq!(df.describe_plan(), INITIAL_PROJECTION_STR); + + test_col_not_found(df.clone().select([col("xyz")]), 0); + test_col_not_found(df.clone().select([col("xyz")]).select([col("c1")]), 1); + test_col_not_found( + df.clone() + .select([col("xyz")]) + .select([col("c1")]) + .select([col("c2")]), + 2, + ); + test_col_not_found( + df.clone() + .select([col("xyz")]) + .select([col("c1")]) + .select([col("c2")]) + .select([col("c3")]), + 3, + ); +} diff --git a/crates/polars-lazy/src/tests/io.rs b/crates/polars-lazy/src/tests/io.rs index 5f30d038166d..0b22a1a33a4d 100644 --- a/crates/polars-lazy/src/tests/io.rs +++ b/crates/polars-lazy/src/tests/io.rs @@ -1,8 +1,11 @@ use polars_io::RowIndex; +#[cfg(feature = "is_between")] +use polars_ops::prelude::ClosedInterval; use super::*; #[test] +#[cfg(feature = "parquet")] fn test_parquet_exec() -> PolarsResult<()> { let _guard = SINGLE_LOCK.lock().unwrap(); // filter @@ -34,6 +37,7 @@ fn test_parquet_exec() -> PolarsResult<()> { } #[test] +#[cfg(all(feature = "parquet", feature = "is_between"))] fn test_parquet_statistics_no_skip() { let _guard = SINGLE_LOCK.lock().unwrap(); init_files(); @@ -62,6 +66,38 @@ fn test_parquet_statistics_no_skip() { .unwrap(); assert_eq!(out.shape(), (27, 4)); + // statistics and `is_between` + // normal case + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(40, 300, ClosedInterval::Both)) + .collect() + .unwrap(); + assert_eq!(out.shape(), (19, 4)); + // normal case + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(10, 50, ClosedInterval::Both)) + .collect() + .unwrap(); + assert_eq!(out.shape(), (11, 4)); + // edge case: 20 = min(calories) but the right end is closed + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(5, 20, ClosedInterval::Right)) + .collect() + .unwrap(); + assert_eq!(out.shape(), (1, 4)); + // edge case: 200 = max(calories) but the left end is closed + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(200, 250, ClosedInterval::Left)) + .collect() + .unwrap(); + assert_eq!(out.shape(), (3, 4)); + // edge case: left == right but both ends are closed + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(200, 200, ClosedInterval::Both)) + .collect() + .unwrap(); + assert_eq!(out.shape(), (3, 4)); + // Or operation let out = scan_foods_parquet(par) .filter( @@ -75,6 +111,7 @@ fn test_parquet_statistics_no_skip() { } #[test] +#[cfg(all(feature = "parquet", feature = "is_between"))] fn test_parquet_statistics() -> PolarsResult<()> { let _guard = SINGLE_LOCK.lock().unwrap(); init_files(); @@ -97,11 +134,187 @@ fn test_parquet_statistics() -> PolarsResult<()> { .collect()?; assert_eq!(out.shape(), (0, 4)); + // issue: 13427 + let out = scan_foods_parquet(par) + .filter(col("calories").is_in(lit(Series::new("", [0, 500])))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // statistics and `is_between` + // 15 < min(calories)=20 + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(5, 15, ClosedInterval::Both)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // 300 > max(calories)=200 + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(300, 500, ClosedInterval::Both)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // 20 == min(calories) but right end is open + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(5, 20, ClosedInterval::Left)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // 20 == min(calories) but both ends are open + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(5, 20, ClosedInterval::None)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // 200 == max(calories) but left end is open + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(200, 250, ClosedInterval::Right)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // 200 == max(calories) but both ends are open + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(200, 250, ClosedInterval::None)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // between(100, 40) is impossible + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(100, 40, ClosedInterval::Both)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // with strings + let out = scan_foods_parquet(par) + .filter(col("category").is_between(lit("yams"), lit("zest"), ClosedInterval::Both)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // with strings + let out = scan_foods_parquet(par) + .filter(col("category").is_between(lit("dairy"), lit("eggs"), ClosedInterval::Both)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + let out = scan_foods_parquet(par) .filter(lit(1000i32).lt(col("calories"))) .collect()?; assert_eq!(out.shape(), (0, 4)); + // not(a > b) => a <= b + let out = scan_foods_parquet(par) + .filter(not(col("calories").gt(5))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(a >= b) => a < b + // note that min(calories)=20 + let out = scan_foods_parquet(par) + .filter(not(col("calories").gt_eq(20))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(a < b) => a >= b + let out = scan_foods_parquet(par) + .filter(not(col("calories").lt(250))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(a <= b) => a > b + // note that max(calories)=200 + let out = scan_foods_parquet(par) + .filter(not(col("calories").lt_eq(200))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(a == b) => a != b + // note that proteins_g=10 for all rows + let out = scan_nutri_score_null_column_parquet(par) + .filter(not(col("proteins_g").eq(10))) + .collect()?; + assert_eq!(out.shape(), (0, 6)); + + // not(a != b) => a == b + // note that proteins_g=10 for all rows + let out = scan_nutri_score_null_column_parquet(par) + .filter(not(col("proteins_g").neq(5))) + .collect()?; + assert_eq!(out.shape(), (0, 6)); + + // not(col(c) is between [a, b]) => col(c) < a or col(c) > b + let out = scan_foods_parquet(par) + .filter(not(col("calories").is_between( + 20, + 200, + ClosedInterval::Both, + ))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(col(c) is between [a, b[) => col(c) < a or col(c) >= b + let out = scan_foods_parquet(par) + .filter(not(col("calories").is_between( + 20, + 201, + ClosedInterval::Left, + ))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(col(c) is between ]a, b]) => col(c) <= a or col(c) > b + let out = scan_foods_parquet(par) + .filter(not(col("calories").is_between( + 19, + 200, + ClosedInterval::Right, + ))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(col(c) is between ]a, b]) => col(c) <= a or col(c) > b + let out = scan_foods_parquet(par) + .filter(not(col("calories").is_between( + 19, + 200, + ClosedInterval::Right, + ))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(col(c) is between ]a, b[) => col(c) <= a or col(c) >= b + let out = scan_foods_parquet(par) + .filter(not(col("calories").is_between( + 19, + 201, + ClosedInterval::None, + ))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not (a or b) => not(a) and not(b) + // note that not(fats_g <= 9) is possible; not(calories > 5) should allow us skip the rg + let out = scan_foods_parquet(par) + .filter(not(col("calories").gt(5).or(col("fats_g").lt_eq(9)))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not (a and b) => not(a) or not(b) + let out = scan_foods_parquet(par) + .filter(not(col("calories").gt(5).and(col("fats_g").lt_eq(12)))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // is_not_null + let out = scan_nutri_score_null_column_parquet(par) + .filter(col("nutri_score").is_not_null()) + .collect()?; + assert_eq!(out.shape(), (0, 6)); + + // not(is_null) (~pl.col('nutri_score').is_null()) + let out = scan_nutri_score_null_column_parquet(par) + .filter(not(col("nutri_score").is_null())) + .collect()?; + assert_eq!(out.shape(), (0, 6)); + // Test multiple predicates // And operation @@ -148,7 +361,7 @@ fn test_parquet_globbing() -> PolarsResult<()> { // for side effects init_files(); let _guard = SINGLE_LOCK.lock().unwrap(); - let glob = "../../examples/datasets/*.parquet"; + let glob = "../../examples/datasets/foods*.parquet"; let df = LazyFrame::scan_parquet( glob, ScanArgsParquet { @@ -194,7 +407,7 @@ fn test_scan_parquet_limit_9001() { fn test_ipc_globbing() -> PolarsResult<()> { // for side effects init_files(); - let glob = "../../examples/datasets/*.ipc"; + let glob = "../../examples/datasets/foods*.ipc"; let df = LazyFrame::scan_ipc( glob, ScanArgsIpc { @@ -226,7 +439,7 @@ fn slice_at_union(lp_arena: &Arena, lp: Node) -> bool { #[test] fn test_csv_globbing() -> PolarsResult<()> { - let glob = "../../examples/datasets/*.csv"; + let glob = "../../examples/datasets/foods*.csv"; let full_df = LazyCsvReader::new(glob).finish()?.collect()?; // all 5 files * 27 rows @@ -263,7 +476,7 @@ fn test_csv_globbing() -> PolarsResult<()> { fn test_ndjson_globbing() -> PolarsResult<()> { // for side effects init_files(); - let glob = "../../examples/datasets/*.ndjson"; + let glob = "../../examples/datasets/foods*.ndjson"; let df = LazyJsonLineReader::new(glob).finish()?.collect()?; assert_eq!(df.shape(), (54, 4)); let cal = df.column("calories")?; diff --git a/crates/polars-lazy/src/tests/mod.rs b/crates/polars-lazy/src/tests/mod.rs index 3ddce27f213f..be7deb799468 100644 --- a/crates/polars-lazy/src/tests/mod.rs +++ b/crates/polars-lazy/src/tests/mod.rs @@ -2,6 +2,7 @@ mod aggregations; mod arity; #[cfg(all(feature = "strings", feature = "cse"))] mod cse; +mod err_msg; #[cfg(feature = "parquet")] mod io; mod logical; @@ -29,7 +30,6 @@ fn load_df() -> DataFrame { } use std::io::Cursor; -use std::iter::FromIterator; use optimization_checks::*; use polars_core::chunked_array::builder::get_list_builder; @@ -41,7 +41,7 @@ use polars_core::prelude::*; pub(crate) use polars_core::SINGLE_LOCK; use polars_io::prelude::*; use polars_plan::logical_plan::{ - ArenaLpIter, OptimizationRule, SimplifyExprRule, StackOptimizer, TypeCoercionRule, + OptimizationRule, SimplifyExprRule, StackOptimizer, TypeCoercionRule, }; #[cfg(feature = "cov")] @@ -56,6 +56,8 @@ static GLOB_CSV: &str = "../../examples/datasets/*.csv"; static GLOB_IPC: &str = "../../examples/datasets/*.ipc"; #[cfg(feature = "parquet")] static FOODS_PARQUET: &str = "../../examples/datasets/foods1.parquet"; +#[cfg(feature = "parquet")] +static NUTRI_SCORE_NULL_COLUMN_PARQUET: &str = "../../examples/datasets/null_nutriscore.parquet"; #[cfg(feature = "csv")] static FOODS_CSV: &str = "../../examples/datasets/foods1.csv"; #[cfg(feature = "ipc")] @@ -77,6 +79,7 @@ fn init_files() { for path in &[ "../../examples/datasets/foods1.csv", "../../examples/datasets/foods2.csv", + "../../examples/datasets/null_nutriscore.csv", ] { for ext in [".parquet", ".ipc", ".ndjson"] { let out_path = path.replace(".csv", ext); @@ -131,6 +134,26 @@ fn scan_foods_parquet(parallel: bool) -> LazyFrame { LazyFrame::scan_parquet(out_path, args).unwrap() } +#[cfg(feature = "parquet")] +fn scan_nutri_score_null_column_parquet(parallel: bool) -> LazyFrame { + init_files(); + let out_path = NUTRI_SCORE_NULL_COLUMN_PARQUET; + let parallel = if parallel { + ParallelStrategy::Auto + } else { + ParallelStrategy::None + }; + + let args = ScanArgsParquet { + n_rows: None, + cache: false, + parallel, + rechunk: true, + ..Default::default() + }; + LazyFrame::scan_parquet(out_path, args).unwrap() +} + pub(crate) fn fruits_cars() -> DataFrame { df!( "A"=> [1, 2, 3, 4, 5], diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index 4d997343e68b..c42948acb524 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -1,10 +1,7 @@ -use polars_core::frame::explode::MeltArgs; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; use super::*; -#[cfg(feature = "range")] -use crate::dsl::arg_sort_by; #[test] fn test_lazy_with_column() { diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 018da9daf94e..8bc431ca50e4 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -46,7 +46,7 @@ rand = { workspace = true, features = ["small_rng"] } version_check = { workspace = true } [features] -simd = ["argminmax/nightly_simd"] +simd = [] nightly = ["polars-utils/nightly"] dtype-categorical = ["polars-core/dtype-categorical"] dtype-date = ["polars-core/dtype-date", "polars-core/temporal"] @@ -87,6 +87,7 @@ string_encoding = ["base64", "hex"] to_dummies = [] interpolate = [] list_to_struct = ["polars-core/dtype-struct"] +array_to_struct = ["polars-core/dtype-array", "polars-core/dtype-struct"] list_count = [] diff = [] pct_change = ["diff"] @@ -107,7 +108,7 @@ merge_sorted = [] top_k = [] pivot = ["polars-core/reinterpret"] cross_join = [] -chunked_ids = ["polars-core/chunked_ids"] +chunked_ids = [] asof_join = ["polars-core/asof_join"] semi_anti_join = [] array_any_all = ["dtype-array"] diff --git a/crates/polars-ops/src/chunked_array/array/dispersion.rs b/crates/polars-ops/src/chunked_array/array/dispersion.rs new file mode 100644 index 000000000000..e7039ac5db2e --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/dispersion.rs @@ -0,0 +1,94 @@ +use super::*; + +pub(super) fn median_with_nulls(ca: &ArrayChunked) -> PolarsResult { + let mut out = match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as f32))) + .with_name(ca.name()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as i64))) + .with_name(ca.name()); + out.into_duration(tu).into_series() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median())) + .with_name(ca.name()); + out.into_series() + }, + }; + out.rename(ca.name()); + Ok(out) +} + +pub(super) fn std_with_nulls(ca: &ArrayChunked, ddof: u8) -> PolarsResult { + let mut out = match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as f32))) + .with_name(ca.name()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as i64))) + .with_name(ca.name()); + out.into_duration(tu).into_series() + }, + _ => { + let out: Float64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().std(ddof))) + .collect(); + out.into_series() + }, + }; + out.rename(ca.name()); + Ok(out) +} + +pub(super) fn var_with_nulls(ca: &ArrayChunked, ddof: u8) -> PolarsResult { + let mut out = match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as f32))) + .with_name(ca.name()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(TimeUnit::Milliseconds) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) + .with_name(ca.name()); + out.into_duration(TimeUnit::Milliseconds).into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(TimeUnit::Microseconds | TimeUnit::Nanoseconds) => { + let out: Int64Chunked = ca + .cast(&DataType::Array( + Box::new(DataType::Duration(TimeUnit::Milliseconds)), + ca.width(), + )) + .unwrap() + .array() + .unwrap() + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) + .with_name(ca.name()); + out.into_duration(TimeUnit::Milliseconds).into_series() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof))) + .with_name(ca.name()); + out.into_series() + }, + }; + out.rename(ca.name()); + Ok(out) +} diff --git a/crates/polars-ops/src/chunked_array/array/get.rs b/crates/polars-ops/src/chunked_array/array/get.rs index 6cb5630676e9..f8fc2e894acf 100644 --- a/crates/polars-ops/src/chunked_array/array/get.rs +++ b/crates/polars-ops/src/chunked_array/array/get.rs @@ -1,7 +1,6 @@ use arrow::legacy::kernels::fixed_size_list::{ sub_fixed_size_list_get, sub_fixed_size_list_get_literal, }; -use polars_core::datatypes::ArrayChunked; use polars_core::prelude::arity::binary_to_series; use super::*; diff --git a/crates/polars-ops/src/chunked_array/array/join.rs b/crates/polars-ops/src/chunked_array/array/join.rs index e3ef1b086daf..69b4d5d3815b 100644 --- a/crates/polars-ops/src/chunked_array/array/join.rs +++ b/crates/polars-ops/src/chunked_array/array/join.rs @@ -1,7 +1,5 @@ use std::fmt::Write; -use polars_core::prelude::ArrayChunked; - use super::*; fn join_literal( @@ -25,13 +23,13 @@ fn join_literal( if ca.null_count() != 0 && !ignore_nulls { return None; } - - let iter = ca.into_iter().flatten(); - - for val in iter { - buf.write_str(val).unwrap(); - buf.write_str(separator).unwrap(); + for arr in ca.downcast_iter() { + for val in arr.non_null_values_iter() { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } } + // last value should not have a separator, so slice that off // saturating sub because there might have been nothing written. Some(&buf[..buf.len().saturating_sub(separator.len())]) @@ -62,11 +60,11 @@ fn join_many( return None; } - let iter = ca.into_iter().flatten(); - - for val in iter { - buf.write_str(val).unwrap(); - buf.write_str(separator).unwrap(); + for arr in ca.downcast_iter() { + for val in arr.non_null_values_iter() { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } } // last value should not have a separator, so slice that off // saturating sub because there might have been nothing written. diff --git a/crates/polars-ops/src/chunked_array/array/min_max.rs b/crates/polars-ops/src/chunked_array/array/min_max.rs index c4857fc94ff9..c61d422e4277 100644 --- a/crates/polars-ops/src/chunked_array/array/min_max.rs +++ b/crates/polars-ops/src/chunked_array/array/min_max.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, PrimitiveArray}; use polars_compute::min_max::MinMaxKernel; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_polars_type; diff --git a/crates/polars-ops/src/chunked_array/array/mod.rs b/crates/polars-ops/src/chunked_array/array/mod.rs index b48fcf9bb0d3..efe4dcbf339c 100644 --- a/crates/polars-ops/src/chunked_array/array/mod.rs +++ b/crates/polars-ops/src/chunked_array/array/mod.rs @@ -1,14 +1,19 @@ #[cfg(feature = "array_any_all")] mod any_all; mod count; +mod dispersion; mod get; mod join; mod min_max; mod namespace; mod sum_mean; +#[cfg(feature = "array_to_struct")] +mod to_struct; pub use namespace::ArrayNameSpace; use polars_core::prelude::*; +#[cfg(feature = "array_to_struct")] +pub use to_struct::*; pub trait AsArray { fn as_array(&self) -> &ArrayChunked; diff --git a/crates/polars-ops/src/chunked_array/array/namespace.rs b/crates/polars-ops/src/chunked_array/array/namespace.rs index 8ad949025131..49c30cd00e0a 100644 --- a/crates/polars-ops/src/chunked_array/array/namespace.rs +++ b/crates/polars-ops/src/chunked_array/array/namespace.rs @@ -51,6 +51,21 @@ pub trait ArrayNameSpace: AsArray { } } + fn array_median(&self) -> PolarsResult { + let ca = self.as_array(); + dispersion::median_with_nulls(ca) + } + + fn array_std(&self, ddof: u8) -> PolarsResult { + let ca = self.as_array(); + dispersion::std_with_nulls(ca, ddof) + } + + fn array_var(&self, ddof: u8) -> PolarsResult { + let ca = self.as_array(); + dispersion::var_with_nulls(ca, ddof) + } + fn array_unique(&self) -> PolarsResult { let ca = self.as_array(); ca.try_apply_amortized_to_list(|s| s.as_ref().unique()) @@ -114,6 +129,39 @@ pub trait ArrayNameSpace: AsArray { let ca = self.as_array(); array_count_matches(ca, element) } + + fn array_shift(&self, n: &Series) -> PolarsResult { + let ca = self.as_array(); + let n_s = n.cast(&DataType::Int64)?; + let n = n_s.i64()?; + let out = match n.len() { + 1 => { + if let Some(n) = n.get(0) { + // SAFETY: Shift does not change the dtype and number of elements of sub-array. + unsafe { ca.apply_amortized_same_type(|s| s.as_ref().shift(n)) } + } else { + ArrayChunked::full_null_with_dtype( + ca.name(), + ca.len(), + &ca.inner_dtype(), + ca.width(), + ) + } + }, + _ => { + // SAFETY: Shift does not change the dtype and number of elements of sub-array. + unsafe { + ca.zip_and_apply_amortized_same_type(n, |opt_s, opt_periods| { + match (opt_s, opt_periods) { + (Some(s), Some(n)) => Some(s.as_ref().shift(n)), + _ => None, + } + }) + } + }, + }; + Ok(out.into_series()) + } } impl ArrayNameSpace for ArrayChunked {} diff --git a/crates/polars-ops/src/chunked_array/array/sum_mean.rs b/crates/polars-ops/src/chunked_array/array/sum_mean.rs index f998e0729bb9..d27f1117fd3a 100644 --- a/crates/polars-ops/src/chunked_array/array/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/array/sum_mean.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, PrimitiveArray}; use arrow::bitmap::Bitmap; use arrow::legacy::utils::CustomIterTools; use arrow::types::NativeType; diff --git a/crates/polars-ops/src/chunked_array/array/to_struct.rs b/crates/polars-ops/src/chunked_array/array/to_struct.rs new file mode 100644 index 000000000000..b14e388ff82b --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/to_struct.rs @@ -0,0 +1,44 @@ +use polars_core::export::rayon::prelude::*; +use polars_core::POOL; +use polars_utils::format_smartstring; +use smartstring::alias::String as SmartString; + +use super::*; + +pub type ArrToStructNameGenerator = Arc SmartString + Send + Sync>; + +pub fn arr_default_struct_name_gen(idx: usize) -> SmartString { + format_smartstring!("field_{idx}") +} + +pub trait ToStruct: AsArray { + fn to_struct( + &self, + name_generator: Option, + ) -> PolarsResult { + let ca = self.as_array(); + let n_fields = ca.width(); + + let name_generator = name_generator + .as_deref() + .unwrap_or(&arr_default_struct_name_gen); + + polars_ensure!(n_fields != 0, ComputeError: "cannot create a struct with 0 fields"); + let fields = POOL.install(|| { + (0..n_fields) + .into_par_iter() + .map(|i| { + ca.array_get(&Int64Chunked::from_slice("", &[i as i64])) + .map(|mut s| { + s.rename(&name_generator(i)); + s + }) + }) + .collect::>>() + })?; + + StructChunked::new(ca.name(), &fields) + } +} + +impl ToStruct for ArrayChunked {} diff --git a/crates/polars-ops/src/chunked_array/gather/chunked.rs b/crates/polars-ops/src/chunked_array/gather/chunked.rs new file mode 100644 index 000000000000..47dfccd58293 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/gather/chunked.rs @@ -0,0 +1,275 @@ +use polars_core::prelude::gather::_update_gather_sorted_flag; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::index::ChunkId; +use polars_utils::slice::GetSaferUnchecked; + +use crate::frame::IntoDf; + +pub trait DfTake: IntoDf { + /// Take elements by a slice of [`ChunkId`]s. + /// + /// # Safety + /// Does not do any bound checks. + /// `sorted` indicates if the chunks are sorted. + unsafe fn _take_chunked_unchecked_seq(&self, idx: &[ChunkId], sorted: IsSorted) -> DataFrame { + let cols = self + .to_df() + ._apply_columns(&|s| s.take_chunked_unchecked(idx, sorted)); + + unsafe { DataFrame::new_no_checks(cols) } + } + /// Take elements by a slice of optional [`ChunkId`]s. + /// + /// # Safety + /// Does not do any bound checks. + unsafe fn _take_opt_chunked_unchecked_seq(&self, idx: &[Option]) -> DataFrame { + let cols = self + .to_df() + ._apply_columns(&|s| s.take_opt_chunked_unchecked(idx)); + + unsafe { DataFrame::new_no_checks(cols) } + } + + /// # Safety + /// Doesn't perform any bound checks + unsafe fn _take_chunked_unchecked(&self, idx: &[ChunkId], sorted: IsSorted) -> DataFrame { + let cols = self + .to_df() + ._apply_columns_par(&|s| s.take_chunked_unchecked(idx, sorted)); + + unsafe { DataFrame::new_no_checks(cols) } + } + + /// # Safety + /// Doesn't perform any bound checks + unsafe fn _take_opt_chunked_unchecked(&self, idx: &[Option]) -> DataFrame { + let cols = self + .to_df() + ._apply_columns_par(&|s| s.take_opt_chunked_unchecked(idx)); + + unsafe { DataFrame::new_no_checks(cols) } + } +} + +impl DfTake for DataFrame {} + +/// Gather by [`ChunkId`] +pub trait TakeChunked { + /// # Safety + /// This function doesn't do any bound checks. + unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self; + + /// # Safety + /// This function doesn't do any bound checks. + unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self; +} + +impl TakeChunked for Series { + unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { + let phys = self.to_physical_repr(); + use DataType::*; + let out = match phys.dtype() { + dt if dt.is_numeric() => { + with_match_physical_numeric_polars_type!(phys.dtype(), |$T| { + let ca: &ChunkedArray<$T> = phys.as_ref().as_ref().as_ref(); + ca.take_chunked_unchecked(by, sorted).into_series() + }) + }, + Boolean => { + let ca = phys.bool().unwrap(); + ca.take_chunked_unchecked(by, sorted).into_series() + }, + Binary => { + let ca = phys.binary().unwrap(); + ca.take_chunked_unchecked(by, sorted).into_series() + }, + String => { + let ca = phys.str().unwrap(); + ca.take_chunked_unchecked(by, sorted).into_series() + }, + List(_) => { + let ca = phys.list().unwrap(); + ca.take_chunked_unchecked(by, sorted).into_series() + }, + #[cfg(feature = "dtype-array")] + Array(_, _) => { + let ca = phys.array().unwrap(); + ca.take_chunked_unchecked(by, sorted).into_series() + }, + #[cfg(feature = "dtype-struct")] + Struct(_) => { + let ca = phys.struct_().unwrap(); + ca._apply_fields(|s| s.take_chunked_unchecked(by, sorted)) + .into_series() + }, + #[cfg(feature = "object")] + Object(_, _) => take_unchecked_object(&phys, by, sorted), + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => { + let ca = phys.decimal().unwrap(); + let out = ca.0.take_chunked_unchecked(by, sorted); + out.into_decimal_unchecked(ca.precision(), ca.scale()) + .into_series() + }, + Null => Series::new_null(self.name(), by.len()), + _ => unreachable!(), + }; + unsafe { out.cast_unchecked(self.dtype()).unwrap() } + } + + unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { + let phys = self.to_physical_repr(); + use DataType::*; + let out = match phys.dtype() { + dt if dt.is_numeric() => { + with_match_physical_numeric_polars_type!(phys.dtype(), |$T| { + let ca: &ChunkedArray<$T> = phys.as_ref().as_ref().as_ref(); + ca.take_opt_chunked_unchecked(by).into_series() + }) + }, + Boolean => { + let ca = phys.bool().unwrap(); + ca.take_opt_chunked_unchecked(by).into_series() + }, + Binary => { + let ca = phys.binary().unwrap(); + ca.take_opt_chunked_unchecked(by).into_series() + }, + String => { + let ca = phys.str().unwrap(); + ca.take_opt_chunked_unchecked(by).into_series() + }, + List(_) => { + let ca = phys.list().unwrap(); + ca.take_opt_chunked_unchecked(by).into_series() + }, + #[cfg(feature = "dtype-array")] + Array(_, _) => { + let ca = phys.array().unwrap(); + ca.take_opt_chunked_unchecked(by).into_series() + }, + #[cfg(feature = "dtype-struct")] + Struct(_) => { + let ca = phys.struct_().unwrap(); + ca._apply_fields(|s| s.take_opt_chunked_unchecked(by)) + .into_series() + }, + #[cfg(feature = "object")] + Object(_, _) => take_opt_unchecked_object(&phys, by), + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => { + let ca = phys.decimal().unwrap(); + let out = ca.0.take_opt_chunked_unchecked(by); + out.into_decimal_unchecked(ca.precision(), ca.scale()) + .into_series() + }, + Null => Series::new_null(self.name(), by.len()), + _ => unreachable!(), + }; + unsafe { out.cast_unchecked(self.dtype()).unwrap() } + } +} + +impl TakeChunked for ChunkedArray +where + T: PolarsDataType, +{ + unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { + let arrow_dtype = self.dtype().to_arrow(true); + + let mut out = if let Some(iter) = self.downcast_slices() { + let targets = iter.collect::>(); + let iter = by.iter().map(|chunk_id| { + let (chunk_idx, array_idx) = chunk_id.extract(); + let vals = targets.get_unchecked_release(chunk_idx as usize); + vals.get_unchecked_release(array_idx as usize).clone() + }); + + let arr = iter.collect_arr_trusted_with_dtype(arrow_dtype); + ChunkedArray::with_chunk(self.name(), arr) + } else { + let targets = self.downcast_iter().collect::>(); + let iter = by.iter().map(|chunk_id| { + let (chunk_idx, array_idx) = chunk_id.extract(); + let vals = targets.get_unchecked_release(chunk_idx as usize); + vals.get_unchecked(array_idx as usize) + }); + let arr = iter.collect_arr_trusted_with_dtype(arrow_dtype); + ChunkedArray::with_chunk(self.name(), arr) + }; + let sorted_flag = _update_gather_sorted_flag(self.is_sorted_flag(), sorted); + out.set_sorted_flag(sorted_flag); + out + } + + unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { + let arrow_dtype = self.dtype().to_arrow(true); + + if let Some(iter) = self.downcast_slices() { + let targets = iter.collect::>(); + let arr = by + .iter() + .map(|chunk_id| { + chunk_id.map(|chunk_id| { + let (chunk_idx, array_idx) = chunk_id.extract(); + let vals = *targets.get_unchecked_release(chunk_idx as usize); + vals.get_unchecked_release(array_idx as usize).clone() + }) + }) + .collect_arr_trusted_with_dtype(arrow_dtype); + + ChunkedArray::with_chunk(self.name(), arr) + } else { + let targets = self.downcast_iter().collect::>(); + let arr = by + .iter() + .map(|chunk_id| { + chunk_id.and_then(|chunk_id| { + let (chunk_idx, array_idx) = chunk_id.extract(); + let vals = *targets.get_unchecked_release(chunk_idx as usize); + vals.get_unchecked(array_idx as usize) + }) + }) + .collect_arr_trusted_with_dtype(arrow_dtype); + + ChunkedArray::with_chunk(self.name(), arr) + } + } +} + +#[cfg(feature = "object")] +unsafe fn take_unchecked_object(s: &Series, by: &[ChunkId], _sorted: IsSorted) -> Series { + let DataType::Object(_, reg) = s.dtype() else { + unreachable!() + }; + let reg = reg.as_ref().unwrap(); + let mut builder = (*reg.builder_constructor)(s.name(), by.len()); + + by.iter().for_each(|chunk_id| { + let (chunk_idx, array_idx) = chunk_id.extract(); + let object = s.get_object_chunked_unchecked(chunk_idx as usize, array_idx as usize); + builder.append_option(object.map(|v| v.as_any())) + }); + builder.to_series() +} + +#[cfg(feature = "object")] +unsafe fn take_opt_unchecked_object(s: &Series, by: &[Option]) -> Series { + let DataType::Object(_, reg) = s.dtype() else { + unreachable!() + }; + let reg = reg.as_ref().unwrap(); + let mut builder = (*reg.builder_constructor)(s.name(), by.len()); + + by.iter().for_each(|chunk_id| match chunk_id { + None => builder.append_null(), + Some(chunk_id) => { + let (chunk_idx, array_idx) = chunk_id.extract(); + let object = s.get_object_chunked_unchecked(chunk_idx as usize, array_idx as usize); + builder.append_option(object.map(|v| v.as_any())) + }, + }); + builder.to_series() +} diff --git a/crates/polars-ops/src/chunked_array/gather/mod.rs b/crates/polars-ops/src/chunked_array/gather/mod.rs new file mode 100644 index 000000000000..fe4a565d63bb --- /dev/null +++ b/crates/polars-ops/src/chunked_array/gather/mod.rs @@ -0,0 +1,4 @@ +#[cfg(feature = "chunked_ids")] +pub(crate) mod chunked; +#[cfg(feature = "chunked_ids")] +pub use chunked::*; diff --git a/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs b/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs index 2dfe4208da1d..ba1427e6f6fd 100644 --- a/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs +++ b/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs @@ -155,7 +155,6 @@ mod test { use rand::distributions::uniform::SampleUniform; use rand::prelude::*; - use rand::rngs::SmallRng; use super::*; @@ -183,12 +182,11 @@ mod test { } fn test_equal_ref(ca: &UInt32Chunked, idx_ca: &IdxCa) { - let ref_ca: Vec> = ca.into_iter().collect(); - let ref_idx_ca: Vec> = - idx_ca.into_iter().map(|i| Some(i? as usize)).collect(); + let ref_ca: Vec> = ca.iter().collect(); + let ref_idx_ca: Vec> = idx_ca.iter().map(|i| Some(i? as usize)).collect(); let gather = ca.gather_skip_nulls(idx_ca).ok(); let ref_gather = ref_gather_nulls(ref_ca, ref_idx_ca); - assert_eq!(gather.map(|ca| ca.into_iter().collect()), ref_gather); + assert_eq!(gather.map(|ca| ca.iter().collect()), ref_gather); } fn gather_skip_nulls_check(ca: &UInt32Chunked, idx_ca: &IdxCa) { diff --git a/crates/polars-ops/src/chunked_array/hist.rs b/crates/polars-ops/src/chunked_array/hist.rs index c16c38a1ae84..5833a1bc784d 100644 --- a/crates/polars-ops/src/chunked_array/hist.rs +++ b/crates/polars-ops/src/chunked_array/hist.rs @@ -1,16 +1,10 @@ use std::fmt::Write; -use arrow::legacy::index::IdxSize; use num_traits::ToPrimitive; -use polars_core::datatypes::PolarsNumericType; -use polars_core::prelude::{ - ChunkCast, ChunkSort, ChunkedArray, DataType, StringChunkedBuilder, StructChunked, UInt32Type, - *, -}; +use polars_core::prelude::*; use polars_core::with_match_physical_numeric_polars_type; -use polars_error::PolarsResult; use polars_utils::float::IsFloat; -use polars_utils::total_ord::TotalOrdWrap; +use polars_utils::total_ord::ToTotalOrd; fn compute_hist( ca: &ChunkedArray, @@ -26,7 +20,7 @@ where let (breaks, count) = if let Some(bins) = bins { let mut breaks = Vec::with_capacity(bins.len() + 1); breaks.extend_from_slice(bins); - breaks.sort_unstable_by_key(|k| TotalOrdWrap(*k)); + breaks.sort_unstable_by_key(|k| k.to_total_ord()); breaks.push(f64::INFINITY); let sorted = ca.sort(false); @@ -67,6 +61,10 @@ where count.push(0) } (breaks, count) + } else if ca.null_count() == ca.len() { + let breaks: Vec = vec![f64::INFINITY]; + let count: Vec = vec![0]; + (breaks, count) } else { let min = ChunkAgg::min(ca).unwrap().to_f64().unwrap(); let max = ChunkAgg::max(ca).unwrap().to_f64().unwrap(); diff --git a/crates/polars-ops/src/chunked_array/list/dispersion.rs b/crates/polars-ops/src/chunked_array/list/dispersion.rs new file mode 100644 index 000000000000..3d47520c1d92 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/dispersion.rs @@ -0,0 +1,86 @@ +use super::*; + +pub(super) fn median_with_nulls(ca: &ListChunked) -> Series { + return match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as f32))) + .with_name(ca.name()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as i64))) + .with_name(ca.name()); + out.into_duration(tu).into_series() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median())) + .with_name(ca.name()); + out.into_series() + }, + }; +} + +pub(super) fn std_with_nulls(ca: &ListChunked, ddof: u8) -> Series { + return match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as f32))) + .with_name(ca.name()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as i64))) + .with_name(ca.name()); + out.into_duration(tu).into_series() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof))) + .with_name(ca.name()); + out.into_series() + }, + }; +} + +pub(super) fn var_with_nulls(ca: &ListChunked, ddof: u8) -> Series { + return match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as f32))) + .with_name(ca.name()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(TimeUnit::Milliseconds) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) + .with_name(ca.name()); + out.into_duration(TimeUnit::Milliseconds).into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(TimeUnit::Microseconds | TimeUnit::Nanoseconds) => { + let out: Int64Chunked = ca + .cast(&DataType::List(Box::new(DataType::Duration( + TimeUnit::Milliseconds, + )))) + .unwrap() + .list() + .unwrap() + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) + .with_name(ca.name()); + out.into_duration(TimeUnit::Milliseconds).into_series() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof))) + .with_name(ca.name()); + out.into_series() + }, + }; +} diff --git a/crates/polars-ops/src/chunked_array/list/hash.rs b/crates/polars-ops/src/chunked_array/list/hash.rs index 5931753f6ebf..fe00dcdceeb6 100644 --- a/crates/polars-ops/src/chunked_array/list/hash.rs +++ b/crates/polars-ops/src/chunked_array/list/hash.rs @@ -1,17 +1,18 @@ use std::hash::Hash; use polars_core::export::_boost_hash_combine; -use polars_core::export::ahash::{self}; use polars_core::export::rayon::prelude::*; use polars_core::utils::NoNull; -use polars_core::POOL; +use polars_core::{with_match_physical_float_polars_type, POOL}; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use super::*; fn hash_agg(ca: &ChunkedArray, random_state: &ahash::RandomState) -> u64 where - T: PolarsIntegerType, - T::Native: Hash, + T: PolarsNumericType, + T::Native: TotalHash + ToTotalOrd, + ::TotalOrdItem: Hash, { // Note that we don't use the no null branch! This can break in unexpected ways. // for instance with threading we split an array in n_threads, this may lead to @@ -30,7 +31,7 @@ where for opt_v in arr.iter() { match opt_v { Some(v) => { - let r = random_state.hash_one(v); + let r = random_state.hash_one(v.to_total_ord()); hash_agg = _boost_hash_combine(hash_agg, r); }, None => { @@ -60,7 +61,12 @@ pub(crate) fn hash(ca: &mut ListChunked, build_hasher: ahash::RandomState) -> UI .map(|opt_s: Option| match opt_s { None => null_hash, Some(s) => { - if s.bit_repr_is_large() { + if s.dtype().is_float() { + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + hash_agg(ca, &build_hasher) + }) + } else if s.bit_repr_is_large() { let ca = s.bit_repr_large(); hash_agg(&ca, &build_hasher) } else { diff --git a/crates/polars-ops/src/chunked_array/list/min_max.rs b/crates/polars-ops/src/chunked_array/list/min_max.rs index 51db2f079b08..dd043110be2e 100644 --- a/crates/polars-ops/src/chunked_array/list/min_max.rs +++ b/crates/polars-ops/src/chunked_array/list/min_max.rs @@ -1,6 +1,5 @@ -use arrow::array::{Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, PrimitiveArray}; use arrow::bitmap::Bitmap; -use arrow::legacy::array::PolarsArray; use arrow::types::NativeType; use polars_compute::min_max::MinMaxKernel; use polars_core::prelude::*; diff --git a/crates/polars-ops/src/chunked_array/list/mod.rs b/crates/polars-ops/src/chunked_array/list/mod.rs index bd0e167528f7..a93b1ed7e2b3 100644 --- a/crates/polars-ops/src/chunked_array/list/mod.rs +++ b/crates/polars-ops/src/chunked_array/list/mod.rs @@ -3,6 +3,7 @@ use polars_core::prelude::*; #[cfg(feature = "list_any_all")] mod any_all; mod count; +mod dispersion; #[cfg(feature = "hash")] pub(crate) mod hash; mod min_max; diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 00bca56ec465..38ca7732c40c 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -1,4 +1,3 @@ -use std::convert::TryFrom; use std::fmt::Write; use arrow::array::ValueSize; @@ -112,12 +111,13 @@ pub trait ListNameSpaceImpl: AsList { return None; } - let iter = ca.into_iter().flatten(); - - for val in iter { - buf.write_str(val).unwrap(); - buf.write_str(separator).unwrap(); + for arr in ca.downcast_iter() { + for val in arr.non_null_values_iter() { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } } + // last value should not have a separator, so slice that off // saturating sub because there might have been nothing written. Some(&buf[..buf.len().saturating_sub(separator.len())]) @@ -151,12 +151,13 @@ pub trait ListNameSpaceImpl: AsList { return None; } - let iter = ca.into_iter().flatten(); - - for val in iter { - buf.write_str(val).unwrap(); - buf.write_str(separator).unwrap(); + for arr in ca.downcast_iter() { + for val in arr.non_null_values_iter() { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } } + // last value should not have a separator, so slice that off // saturating sub because there might have been nothing written. Some(&buf[..buf.len().saturating_sub(separator.len())]) @@ -216,6 +217,21 @@ pub trait ListNameSpaceImpl: AsList { } } + fn lst_median(&self) -> Series { + let ca = self.as_list(); + dispersion::median_with_nulls(ca) + } + + fn lst_std(&self, ddof: u8) -> Series { + let ca = self.as_list(); + dispersion::std_with_nulls(ca, ddof) + } + + fn lst_var(&self, ddof: u8) -> Series { + let ca = self.as_list(); + dispersion::var_with_nulls(ca, ddof) + } + fn same_type(&self, out: ListChunked) -> ListChunked { let ca = self.as_list(); let dtype = ca.dtype(); @@ -240,6 +256,14 @@ pub trait ListNameSpaceImpl: AsList { self.same_type(out) } + fn lst_n_unique(&self) -> PolarsResult { + let ca = self.as_list(); + ca.try_apply_amortized_generic(|s| { + let opt_v = s.map(|s| s.as_ref().n_unique()).transpose()?; + Ok(opt_v.map(|idx| idx as IdxSize)) + }) + } + fn lst_unique(&self) -> PolarsResult { let ca = self.as_list(); let out = ca.try_apply_amortized(|s| s.as_ref().unique())?; @@ -257,7 +281,6 @@ pub trait ListNameSpaceImpl: AsList { ca.apply_amortized_generic(|opt_s| { opt_s.and_then(|s| s.as_ref().arg_min().map(|idx| idx as IdxSize)) }) - .with_name(ca.name()) } fn lst_arg_max(&self) -> IdxCa { @@ -265,7 +288,6 @@ pub trait ListNameSpaceImpl: AsList { ca.apply_amortized_generic(|opt_s| { opt_s.and_then(|s| s.as_ref().arg_max().map(|idx| idx as IdxSize)) }) - .with_name(ca.name()) } #[cfg(feature = "diff")] @@ -326,7 +348,7 @@ pub trait ListNameSpaceImpl: AsList { .downcast_iter() .map(|arr| sublist_get(arr, idx)) .collect::>(); - // Safety: every element in list has dtype equal to its inner type + // SAFETY: every element in list has dtype equal to its inner type unsafe { Series::try_from((ca.name(), chunks)) .unwrap() @@ -334,6 +356,69 @@ pub trait ListNameSpaceImpl: AsList { } } + #[cfg(feature = "list_gather")] + fn lst_gather_every(&self, n: &IdxCa, offset: &IdxCa) -> PolarsResult { + let list_ca = self.as_list(); + let out = match (n.len(), offset.len()) { + (1, 1) => match (n.get(0), offset.get(0)) { + (Some(n), Some(offset)) => list_ca + .apply_amortized(|s| s.as_ref().gather_every(n as usize, offset as usize)), + _ => ListChunked::full_null_with_dtype( + list_ca.name(), + list_ca.len(), + &list_ca.inner_dtype(), + ), + }, + (1, len_offset) if len_offset == list_ca.len() => { + if let Some(n) = n.get(0) { + list_ca.zip_and_apply_amortized(offset, |opt_s, opt_offset| { + match (opt_s, opt_offset) { + (Some(s), Some(offset)) => { + Some(s.as_ref().gather_every(n as usize, offset as usize)) + }, + _ => None, + } + }) + } else { + ListChunked::full_null_with_dtype( + list_ca.name(), + list_ca.len(), + &list_ca.inner_dtype(), + ) + } + }, + (len_n, 1) if len_n == list_ca.len() => { + if let Some(offset) = offset.get(0) { + list_ca.zip_and_apply_amortized(n, |opt_s, opt_n| match (opt_s, opt_n) { + (Some(s), Some(n)) => { + Some(s.as_ref().gather_every(n as usize, offset as usize)) + }, + _ => None, + }) + } else { + ListChunked::full_null_with_dtype( + list_ca.name(), + list_ca.len(), + &list_ca.inner_dtype(), + ) + } + }, + (len_n, len_offset) if len_n == len_offset && len_n == list_ca.len() => list_ca + .binary_zip_and_apply_amortized(n, offset, |opt_s, opt_n, opt_offset| { + match (opt_s, opt_n, opt_offset) { + (Some(s), Some(n), Some(offset)) => { + Some(s.as_ref().gather_every(n as usize, offset as usize)) + }, + _ => None, + } + }), + _ => { + polars_bail!(ComputeError: "The lengths of `n` and `offset` should be 1 or equal to the length of list.") + }, + }; + Ok(out.into_series()) + } + #[cfg(feature = "list_gather")] fn lst_gather(&self, idx: &Series, null_on_oob: bool) -> PolarsResult { let list_ca = self.as_list(); @@ -599,7 +684,7 @@ pub trait ListNameSpaceImpl: AsList { // SAFETY: unstable series never lives longer than the iterator. iters.push(unsafe { s.list()?.amortized_iter() }) } - let mut first_iter = ca.into_iter(); + let mut first_iter: Box>> = ca.into_iter(); let mut builder = get_list_builder( &inner_super_type, ca.get_values_size() + vals_size_other + 1, @@ -668,7 +753,7 @@ fn cast_signed_index_ca(idx: &ChunkedArray, len: usize) where T::Native: Copy + PartialOrd + PartialEq + NumCast + Signed + Zero, { - idx.into_iter() + idx.iter() .map(|opt_idx| opt_idx.and_then(|idx| idx.negative_to_usize(len).map(|idx| idx as IdxSize))) .collect::() .into_series() @@ -679,7 +764,7 @@ fn cast_unsigned_index_ca(idx: &ChunkedArray, len: usiz where T::Native: Copy + PartialOrd + ToPrimitive, { - idx.into_iter() + idx.iter() .map(|opt_idx| { opt_idx.and_then(|idx| { let idx = idx.to_usize().unwrap(); diff --git a/crates/polars-ops/src/chunked_array/list/sets.rs b/crates/polars-ops/src/chunked_array/list/sets.rs index 9fb0373d9d6e..e1ba4f476af8 100644 --- a/crates/polars-ops/src/chunked_array/list/sets.rs +++ b/crates/polars-ops/src/chunked_array/list/sets.rs @@ -11,7 +11,7 @@ use arrow::offset::OffsetsBuffer; use arrow::types::NativeType; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_type; -use polars_utils::total_ord::{TotalEq, TotalHash, TotalOrdWrap}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrdWrap}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -30,12 +30,12 @@ where } } -impl MaterializeValues>> for MutablePrimitiveArray +impl MaterializeValues>> for MutablePrimitiveArray where T: NativeType, { - fn extend_buf>>>(&mut self, values: I) -> usize { - self.extend(values); + fn extend_buf>>>(&mut self, values: I) -> usize { + self.extend(values.map(|x| x.0)); self.len() } } @@ -84,7 +84,7 @@ where SetOperation::Difference => { set.extend(a); for v in b { - set.remove(&v); + set.swap_remove(&v); } out.extend_buf(set.drain(..)) }, @@ -102,8 +102,10 @@ where } } -fn copied_wrapper_opt(v: Option<&T>) -> Option> { - v.copied().map(TotalOrdWrap) +fn copied_wrapper_opt( + v: Option<&T>, +) -> as ToTotalOrd>::TotalOrdItem { + v.copied().to_total_ord() } #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] @@ -136,13 +138,14 @@ fn primitive( validity: Option, ) -> PolarsResult> where - T: NativeType + TotalHash + Copy + TotalEq, + T: NativeType + TotalHash + TotalEq + Copy + ToTotalOrd, + as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, { let broadcast_lhs = offsets_a.len() == 2; let broadcast_rhs = offsets_b.len() == 2; let mut set = Default::default(); - let mut set2: PlIndexSet>> = Default::default(); + let mut set2: PlIndexSet< as ToTotalOrd>::TotalOrdItem> = Default::default(); let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max( *offsets_a.last().unwrap(), @@ -151,9 +154,6 @@ where let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len())); offsets.push(0i64); - if broadcast_rhs { - set2.extend(b.into_iter().map(copied_wrapper_opt)); - } let offsets_slice = if offsets_a.len() > offsets_b.len() { offsets_a } else { @@ -163,6 +163,14 @@ where let second_a = offsets_a[1]; let first_b = offsets_b[0]; let second_b = offsets_b[1]; + if broadcast_rhs { + set2.extend( + b.into_iter() + .skip(first_b as usize) + .take(second_b as usize - first_b as usize) + .map(copied_wrapper_opt), + ); + } for i in 1..offsets_slice.len() { // If we go OOB we take the first element as we are then broadcasting. let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize; @@ -180,7 +188,11 @@ where .skip(start_a) .take(end_a - start_a) .map(copied_wrapper_opt); - let b_iter = b.into_iter().map(copied_wrapper_opt); + let b_iter = b + .into_iter() + .skip(first_b as usize) + .take(second_b as usize - first_b as usize) + .map(copied_wrapper_opt); set_operation( &mut set, &mut set2, @@ -191,7 +203,11 @@ where true, ) } else if broadcast_lhs { - let a_iter = a.into_iter().map(copied_wrapper_opt); + let a_iter = a + .into_iter() + .skip(first_a as usize) + .take(second_a as usize - first_a as usize) + .map(copied_wrapper_opt); let b_iter = b .into_iter() @@ -368,7 +384,7 @@ fn array_set_operation( binary(&a, &b, offsets_a, offsets_b, set_op, validity, true) }, - ArrowDataType::LargeBinary => { + ArrowDataType::BinaryView => { let a = values_a.as_any().downcast_ref::().unwrap(); let b = values_b.as_any().downcast_ref::().unwrap(); binary(a, b, offsets_a, offsets_b, set_op, validity, false) @@ -404,6 +420,13 @@ pub fn list_set_operation( a.prune_empty_chunks(); b.prune_empty_chunks(); + // Make categoricals compatible + if let (DataType::Categorical(_, _), DataType::Categorical(_, _)) = + (&a.inner_dtype(), &b.inner_dtype()) + { + (a, b) = make_list_categoricals_compatible(a, b)?; + } + // we use the unsafe variant because we want to keep the nested logical types type. unsafe { arity::try_binary_unchecked_same_type( diff --git a/crates/polars-ops/src/chunked_array/list/sum_mean.rs b/crates/polars-ops/src/chunked_array/list/sum_mean.rs index fe48e397459b..e3a14e2340f7 100644 --- a/crates/polars-ops/src/chunked_array/list/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/list/sum_mean.rs @@ -1,10 +1,9 @@ use std::ops::Div; -use arrow::array::{Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, PrimitiveArray}; use arrow::bitmap::Bitmap; use arrow::legacy::utils::CustomIterTools; use arrow::types::NativeType; -use polars_core::datatypes::ListChunked; use polars_core::export::num::{NumCast, ToPrimitive}; use polars_utils::unwrap::UnwrapUncheckedRelease; diff --git a/crates/polars-ops/src/chunked_array/list/to_struct.rs b/crates/polars-ops/src/chunked_array/list/to_struct.rs index 765ca2394674..c43cfda13024 100644 --- a/crates/polars-ops/src/chunked_array/list/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/list/to_struct.rs @@ -1,4 +1,5 @@ use polars_core::export::rayon::prelude::*; +use polars_core::POOL; use polars_utils::format_smartstring; use smartstring::alias::String as SmartString; @@ -67,15 +68,17 @@ pub trait ToStruct: AsList { .unwrap_or(&_default_struct_name_gen); polars_ensure!(n_fields != 0, ComputeError: "cannot create a struct with 0 fields"); - let fields = (0..n_fields) - .into_par_iter() - .map(|i| { - ca.lst_get(i as i64).map(|mut s| { - s.rename(&name_generator(i)); - s + let fields = POOL.install(|| { + (0..n_fields) + .into_par_iter() + .map(|i| { + ca.lst_get(i as i64).map(|mut s| { + s.rename(&name_generator(i)); + s + }) }) - }) - .collect::>>()?; + .collect::>>() + })?; StructChunked::new(ca.name(), &fields) } diff --git a/crates/polars-ops/src/chunked_array/mod.rs b/crates/polars-ops/src/chunked_array/mod.rs index 9a6a3c0f727e..31729d7c7c67 100644 --- a/crates/polars-ops/src/chunked_array/mod.rs +++ b/crates/polars-ops/src/chunked_array/mod.rs @@ -21,6 +21,7 @@ pub mod mode; #[cfg(feature = "cov")] pub mod cov; +pub(crate) mod gather; #[cfg(feature = "gather")] pub mod gather_skip_nulls; #[cfg(feature = "hist")] @@ -31,6 +32,8 @@ mod repeat_by; pub use binary::*; #[cfg(feature = "timezones")] pub use datetime::*; +#[cfg(feature = "chunked_ids")] +pub use gather::*; #[cfg(feature = "hist")] pub use hist::*; #[cfg(feature = "interpolate")] diff --git a/crates/polars-ops/src/chunked_array/mode.rs b/crates/polars-ops/src/chunked_array/mode.rs index 6e83d59744f2..26b728306c5e 100644 --- a/crates/polars-ops/src/chunked_array/mode.rs +++ b/crates/polars-ops/src/chunked_array/mode.rs @@ -1,5 +1,4 @@ use arrow::legacy::utils::CustomIterTools; -use polars_core::frame::group_by::IntoGroupsProxy; use polars_core::prelude::*; use polars_core::{with_match_physical_integer_polars_type, POOL}; @@ -14,7 +13,7 @@ where let groups = ca.group_tuples(parallel, false).unwrap(); let idx = mode_indices(groups); - // Safety: + // SAFETY: // group indices are in bounds Ok(unsafe { ca.take_unchecked(idx.as_slice()) }) } diff --git a/crates/polars-ops/src/chunked_array/repeat_by.rs b/crates/polars-ops/src/chunked_array/repeat_by.rs index bd844501f94d..bdba858d5719 100644 --- a/crates/polars-ops/src/chunked_array/repeat_by.rs +++ b/crates/polars-ops/src/chunked_array/repeat_by.rs @@ -1,5 +1,4 @@ use arrow::array::ListArray; -use arrow::legacy::array::ListFromIter; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_polars_type; diff --git a/crates/polars-ops/src/chunked_array/scatter.rs b/crates/polars-ops/src/chunked_array/scatter.rs index 26ea76cd66ce..080f2b1da9c5 100644 --- a/crates/polars-ops/src/chunked_array/scatter.rs +++ b/crates/polars-ops/src/chunked_array/scatter.rs @@ -100,7 +100,7 @@ where let mut ca = self.rechunk(); drop(self); - // safety: + // SAFETY: // we will not modify the length // and we unset the sorted flag. ca.set_sorted_flag(IsSorted::Not); @@ -113,13 +113,13 @@ where // reborrow because the bck does not allow it let current_values = unsafe { &mut *std::slice::from_raw_parts_mut(ptr, len) }; - // Safety: + // SAFETY: // we checked bounds unsafe { scatter_impl(current_values, values, arr, idx, len) }; }, None => { let mut new_values = arr.values().as_slice().to_vec(); - // Safety: + // SAFETY: // we checked bounds unsafe { scatter_impl(&mut new_values, values, arr, idx, len) }; arr.set_values(new_values.into()); diff --git a/crates/polars-ops/src/chunked_array/strings/concat.rs b/crates/polars-ops/src/chunked_array/strings/concat.rs index 7689ab4fe883..c290b076cac4 100644 --- a/crates/polars-ops/src/chunked_array/strings/concat.rs +++ b/crates/polars-ops/src/chunked_array/strings/concat.rs @@ -1,6 +1,5 @@ use arrow::array::{Utf8Array, ValueSize}; use arrow::compute::cast::utf8_to_utf8view; -use arrow::legacy::array::default_arrays::FromDataUtf8; use polars_core::prelude::*; // Vertically concatenate all strings in a StringChunked. @@ -91,7 +90,7 @@ pub fn hor_str_concat( .iter() .map(|ca| { if ca.len() > 1 { - ColumnIter::Iter(ca.into_iter()) + ColumnIter::Iter(ca.iter()) } else { ColumnIter::Broadcast(ca.get(0)) } diff --git a/crates/polars-ops/src/chunked_array/strings/json_path.rs b/crates/polars-ops/src/chunked_array/strings/json_path.rs index c32bc9919c35..3b8edcaea962 100644 --- a/crates/polars-ops/src/chunked_array/strings/json_path.rs +++ b/crates/polars-ops/src/chunked_array/strings/json_path.rs @@ -55,7 +55,7 @@ pub trait Utf8JsonPathImpl: AsString { fn json_infer(&self, number_of_rows: Option) -> PolarsResult { let ca = self.as_string(); let values_iter = ca - .into_iter() + .iter() .map(|x| x.unwrap_or("null")) .take(number_of_rows.unwrap_or(ca.len())); @@ -76,7 +76,7 @@ pub trait Utf8JsonPathImpl: AsString { None => ca.json_infer(infer_schema_len)?, }; let buf_size = ca.get_values_size() + ca.null_count() * "null".len(); - let iter = ca.into_iter().map(|x| x.unwrap_or("null")); + let iter = ca.iter().map(|x| x.unwrap_or("null")); let array = polars_json::ndjson::deserialize::deserialize_iter( iter, diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 8c59e1fdbeea..83713b788952 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -383,10 +383,12 @@ pub trait StringNameSpaceImpl: AsString { let reg = Regex::new(pat)?; let mut builder = ListStringChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); - for opt_s in ca.into_iter() { - match opt_s { - None => builder.append_null(), - Some(s) => builder.append_values_iter(reg.find_iter(s).map(|m| m.as_str())), + for arr in ca.downcast_iter() { + for opt_s in arr { + match opt_s { + None => builder.append_null(), + Some(s) => builder.append_values_iter(reg.find_iter(s).map(|m| m.as_str())), + } } } Ok(builder.finish()) diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index 04618c46f467..32354f057a65 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -15,12 +15,11 @@ pub type ChunkJoinOptIds = Vec>; #[cfg(not(feature = "chunked_ids"))] pub type ChunkJoinIds = Vec; +#[cfg(feature = "chunked_ids")] +use polars_utils::index::ChunkId; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[cfg(feature = "asof_join")] -use super::asof::AsOfOptions; - #[derive(Clone, PartialEq, Eq, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct JoinArgs { diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs index f64f7d100984..63b3e512fa7e 100644 --- a/crates/polars-ops/src/frame/join/asof/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -4,15 +4,15 @@ use ahash::RandomState; use num_traits::Zero; use polars_core::hashing::{_df_rows_to_hashes_threaded_vertical, _HASHMAP_INIT_SIZE}; use polars_core::utils::{split_ca, split_df}; -use polars_core::POOL; +use polars_core::{with_match_physical_float_polars_type, POOL}; use polars_utils::abs_diff::AbsDiff; use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use rayon::prelude::*; use smartstring::alias::String as SmartString; use super::*; -use crate::frame::IntoDf; fn compute_len_offsets>(iter: I) -> Vec { let mut cumlen = 0; @@ -71,7 +71,8 @@ fn asof_join_by_numeric( where T: PolarsDataType, S: PolarsNumericType, - S::Native: Hash + Eq + DirtyHash + IsNull, + S::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, A: for<'a> AsofJoinState>, F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool, { @@ -95,49 +96,47 @@ where let n_tables = hash_tbls.len(); // Now we probe the right hand side for each left hand side. - Ok(POOL - .install(|| { - split_by_left - .into_par_iter() - .zip(offsets) - .flat_map(|(by_left, offset)| { - let mut results = Vec::with_capacity(by_left.len()); - let mut group_states: PlHashMap = - PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); - - let by_left_chunk = by_left.downcast_iter().next().unwrap(); - for (rel_idx_left, opt_by_left_k) in by_left_chunk.iter().enumerate() { - let Some(by_left_k) = opt_by_left_k else { - results.push(None); - continue; - }; - let idx_left = (rel_idx_left + offset) as IdxSize; - let Some(left_val) = left_val_arr.get(idx_left as usize) else { - results.push(None); - continue; - }; - - let group_probe_table = unsafe { - hash_tbls - .get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables)) - }; - let Some(right_grp_idxs) = group_probe_table.get(by_left_k) else { - results.push(None); - continue; - }; + let out = split_by_left + .into_par_iter() + .zip(offsets) + .flat_map(|(by_left, offset)| { + let mut results = Vec::with_capacity(by_left.len()); + let mut group_states: PlHashMap = + PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); + + let by_left_chunk = by_left.downcast_iter().next().unwrap(); + for (rel_idx_left, opt_by_left_k) in by_left_chunk.iter().enumerate() { + let Some(by_left_k) = opt_by_left_k else { + results.push(None); + continue; + }; + let by_left_k = by_left_k.to_total_ord(); + let idx_left = (rel_idx_left + offset) as IdxSize; + let Some(left_val) = left_val_arr.get(idx_left as usize) else { + results.push(None); + continue; + }; + + let group_probe_table = unsafe { + hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables)) + }; + let Some(right_grp_idxs) = group_probe_table.get(&by_left_k) else { + results.push(None); + continue; + }; + + results.push(asof_in_group::( + left_val, + right_val_arr, + right_grp_idxs.as_slice(), + &mut group_states, + &filter, + )); + } + results + }); - results.push(asof_in_group::( - left_val, - right_val_arr, - right_grp_idxs.as_slice(), - &mut group_states, - &filter, - )); - } - results - }) - }) - .collect()) + Ok(POOL.install(|| out.collect())) } fn asof_join_by_binary( @@ -329,7 +328,15 @@ where asof_join_by_binary::(left_by, right_by, left_asof, right_asof, filter) }, _ => { - if left_by_s.bit_repr_is_large() { + if left_by_s.dtype().is_float() { + with_match_physical_float_polars_type!(left_by_s.dtype(), |$T| { + let left_by: &ChunkedArray<$T> = left_by_s.as_ref().as_ref().as_ref(); + let right_by: &ChunkedArray<$T> = right_by_s.as_ref().as_ref().as_ref(); + asof_join_by_numeric::( + left_by, right_by, left_asof, right_asof, filter, + )? + }) + } else if left_by_s.bit_repr_is_large() { let left_by = left_by_s.bit_repr_large(); let right_by = right_by_s.bit_repr_large(); asof_join_by_numeric::( @@ -560,7 +567,7 @@ pub trait AsofJoinBy: IntoDf { .filter(|s| !drop_these.contains(&s.name())) .cloned() .collect(); - let proj_other_df = DataFrame::new_no_checks(cols); + let proj_other_df = unsafe { DataFrame::new_no_checks(cols) }; let left = self_df.clone(); let right_join_tuples = &*right_join_tuples; diff --git a/crates/polars-ops/src/frame/join/cross_join.rs b/crates/polars-ops/src/frame/join/cross_join.rs index 1bc596960b2c..1e1b1bcba497 100644 --- a/crates/polars-ops/src/frame/join/cross_join.rs +++ b/crates/polars-ops/src/frame/join/cross_join.rs @@ -1,6 +1,4 @@ -use polars_core::series::IsSorted; -use polars_core::utils::{concat_df_unchecked, slice_offsets, CustomIterTools, NoNull}; -use polars_core::POOL; +use polars_core::utils::{concat_df_unchecked, CustomIterTools, NoNull}; use smartstring::alias::String as SmartString; use super::*; @@ -70,7 +68,7 @@ pub trait CrossJoin: IntoDf { // right take idx: 012301230123 let create_left_df = || { - // Safety: + // SAFETY: // take left is in bounds unsafe { df_self.take_unchecked(&take_left(total_rows, n_rows_right, slice)) } }; @@ -80,7 +78,7 @@ pub trait CrossJoin: IntoDf { // many times, these are atomic operations // so we choose a different strategy at > 100 rows (arbitrarily small number) if n_rows_left > 100 || slice.is_some() { - // Safety: + // SAFETY: // take right is in bounds unsafe { other.take_unchecked(&take_right(total_rows, n_rows_right, slice)) } } else { @@ -120,7 +118,7 @@ pub trait CrossJoin: IntoDf { Ok(l_df) } - /// Creates the cartesian product from both frames, preserves the order of the left keys. + /// Creates the Cartesian product from both frames, preserves the order of the left keys. fn cross_join( &self, other: &DataFrame, diff --git a/crates/polars-ops/src/frame/join/general.rs b/crates/polars-ops/src/frame/join/general.rs index a6aaa25fca95..934b27caad6a 100644 --- a/crates/polars-ops/src/frame/join/general.rs +++ b/crates/polars-ops/src/frame/join/general.rs @@ -1,4 +1,5 @@ -use std::borrow::Cow; +#[cfg(feature = "chunked_ids")] +use polars_utils::index::ChunkId; use super::*; use crate::series::coalesce_series; @@ -93,7 +94,9 @@ pub(crate) fn create_chunked_index_mapping(chunks: &[ArrayRef], len: usize) -> V let mut vals = Vec::with_capacity(len); for (chunk_i, chunk) in chunks.iter().enumerate() { - vals.extend((0..chunk.len()).map(|array_i| [chunk_i as IdxSize, array_i as IdxSize])) + vals.extend( + (0..chunk.len()).map(|array_i| ChunkId::store(chunk_i as IdxSize, array_i as IdxSize)), + ) } vals diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index ad25b39a6d40..891164dcccfe 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -7,14 +7,12 @@ mod single_keys_outer; #[cfg(feature = "semi_anti_join")] mod single_keys_semi_anti; pub(super) mod sort_merge; - use arrow::array::ArrayRef; pub use multiple_keys::private_left_join_multiple_keys; pub(super) use multiple_keys::*; -#[cfg(any(feature = "chunked_ids", feature = "semi_anti_join"))] -use polars_core::utils::slice_slice; -use polars_core::utils::{_set_partition_size, slice_offsets, split_ca}; +use polars_core::utils::{_set_partition_size, split_ca}; use polars_core::POOL; +use polars_utils::index::ChunkId; pub(super) use single_keys::*; #[cfg(feature = "asof_join")] pub(super) use single_keys_dispatch::prepare_bytes; @@ -27,6 +25,8 @@ use single_keys_semi_anti::*; pub use sort_merge::*; pub use super::*; +#[cfg(feature = "chunked_ids")] +use crate::chunked_array::gather::chunked::DfTake; pub fn default_join_ids() -> ChunkJoinOptIds { #[cfg(feature = "chunked_ids")] @@ -54,8 +54,6 @@ macro_rules! det_hash_prone_order { use arrow::legacy::conversion::primitive_to_vec; pub(super) use det_hash_prone_order; -use crate::frame::join::general::coalesce_outer_join; - pub trait JoinDispatch: IntoDf { /// # Safety /// Join tuples must be in bounds @@ -233,7 +231,7 @@ pub trait JoinDispatch: IntoDf { _check_categorical_src(s_left.dtype(), s_right.dtype())?; let idx = s_left.hash_join_semi_anti(s_right, anti); - // Safety: + // SAFETY: // indices are in bounds Ok(unsafe { ca_self._finish_anti_semi_join(&idx, slice) }) } diff --git a/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs b/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs index a493d43f6ff0..7e2c737da070 100644 --- a/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs +++ b/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs @@ -1,12 +1,10 @@ use arrow::array::{MutablePrimitiveArray, PrimitiveArray}; -use hashbrown::hash_map::RawEntryMut; use hashbrown::HashMap; -use polars_core::hashing::{ - populate_multiple_key_hashmap, IdBuildHasher, IdxHash, _HASHMAP_INIT_SIZE, -}; -use polars_core::utils::{_set_partition_size, split_df}; -use polars_core::POOL; +use polars_core::hashing::{populate_multiple_key_hashmap, IdBuildHasher, IdxHash}; +use polars_core::utils::split_df; use polars_utils::hashing::hash_to_partition; +use polars_utils::idx_vec::IdxVec; +use polars_utils::unitvec; use super::*; @@ -31,7 +29,7 @@ pub(crate) unsafe fn compare_df_rows2( pub(crate) fn create_probe_table( hashes: &[UInt64Chunked], keys: &DataFrame, -) -> Vec, IdBuildHasher>> { +) -> Vec> { let n_partitions = _set_partition_size(); // We will create a hashtable in every thread. @@ -41,7 +39,7 @@ pub(crate) fn create_probe_table( (0..n_partitions) .into_par_iter() .map(|part_no| { - let mut hash_tbl: HashMap, IdBuildHasher> = + let mut hash_tbl: HashMap = HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); let mut offset = 0; @@ -59,7 +57,7 @@ pub(crate) fn create_probe_table( idx, *h, keys, - || vec![idx], + || unitvec![idx], |v| v.push(idx), ) } @@ -78,7 +76,7 @@ pub(crate) fn create_probe_table( fn create_build_table_outer( hashes: &[UInt64Chunked], keys: &DataFrame, -) -> Vec), IdBuildHasher>> { +) -> Vec> { // Outer join equivalent of create_build_table() adds a bool in the hashmap values for tracking // whether a value in the hash table has already been matched to a value in the probe hashes. let n_partitions = _set_partition_size(); @@ -86,47 +84,46 @@ fn create_build_table_outer( // We will create a hashtable in every thread. // We use the hash to partition the keys to the matching hashtable. // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - POOL.install(|| { - (0..n_partitions).into_par_iter().map(|part_no| { - let mut hash_tbl: HashMap), IdBuildHasher> = - HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); - - let mut offset = 0; - for hashes in hashes { - for hashes in hashes.data_views() { - let len = hashes.len(); - let mut idx = 0; - hashes.iter().for_each(|h| { - // partition hashes by thread no. - // So only a part of the hashes go to this hashmap - if part_no == hash_to_partition(*h, n_partitions) { - let idx = idx + offset; - populate_multiple_key_hashmap( - &mut hash_tbl, - idx, - *h, - keys, - || (false, vec![idx]), - |v| v.1.push(idx), - ) - } - idx += 1; - }); + let par_iter = (0..n_partitions).into_par_iter().map(|part_no| { + let mut hash_tbl: HashMap = + HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); + + let mut offset = 0; + for hashes in hashes { + for hashes in hashes.data_views() { + let len = hashes.len(); + let mut idx = 0; + hashes.iter().for_each(|h| { + // partition hashes by thread no. + // So only a part of the hashes go to this hashmap + if part_no == hash_to_partition(*h, n_partitions) { + let idx = idx + offset; + populate_multiple_key_hashmap( + &mut hash_tbl, + idx, + *h, + keys, + || (false, unitvec![idx]), + |v| v.1.push(idx), + ) + } + idx += 1; + }); - offset += len as IdxSize; - } + offset += len as IdxSize; } - hash_tbl - }) - }) - .collect() + } + hash_tbl + }); + + POOL.install(|| par_iter.collect()) } /// Probe the build table and add tuples to the results (inner join) #[allow(clippy::too_many_arguments)] fn probe_inner( probe_hashes: &UInt64Chunked, - hash_tbls: &[HashMap, IdBuildHasher>], + hash_tbls: &[HashMap], results: &mut Vec<(IdxSize, IdxSize)>, local_offset: usize, n_tables: usize, @@ -146,7 +143,7 @@ fn probe_inner( let entry = current_probe_table.raw_entry().from_hash(h, |idx_hash| { let idx_b = idx_hash.idx; - // Safety: + // SAFETY: // indices in a join operation are always in bounds. unsafe { compare_df_rows2(a, b, idx_a as usize, idx_b as usize, join_nulls) } }); @@ -249,8 +246,8 @@ pub fn private_left_join_multiple_keys( chunk_mapping_right: Option<&[ChunkId]>, join_nulls: bool, ) -> LeftJoinIds { - let mut a = DataFrame::new_no_checks(_to_physical_and_bit_repr(a.get_columns())); - let mut b = DataFrame::new_no_checks(_to_physical_and_bit_repr(b.get_columns())); + let mut a = unsafe { DataFrame::new_no_checks(_to_physical_and_bit_repr(a.get_columns())) }; + let mut b = unsafe { DataFrame::new_no_checks(_to_physical_and_bit_repr(b.get_columns())) }; _left_join_multiple_keys( &mut a, &mut b, @@ -312,7 +309,7 @@ pub fn _left_join_multiple_keys( let entry = current_probe_table.raw_entry().from_hash(h, |idx_hash| { let idx_b = idx_hash.idx; - // Safety: + // SAFETY: // indices in a join operation are always in bounds. unsafe { compare_df_rows2(a, b, idx_a as usize, idx_b as usize, join_nulls) @@ -358,40 +355,32 @@ pub(crate) fn create_build_table_semi_anti( // We will create a hashtable in every thread. // We use the hash to partition the keys to the matching hashtable. // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - POOL.install(|| { - (0..n_partitions).into_par_iter().map(|part_no| { - let mut hash_tbl: HashMap = - HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); - - let mut offset = 0; - for hashes in hashes { - for hashes in hashes.data_views() { - let len = hashes.len(); - let mut idx = 0; - hashes.iter().for_each(|h| { - // partition hashes by thread no. - // So only a part of the hashes go to this hashmap - if part_no == hash_to_partition(*h, n_partitions) { - let idx = idx + offset; - populate_multiple_key_hashmap( - &mut hash_tbl, - idx, - *h, - keys, - || (), - |_| (), - ) - } - idx += 1; - }); + let par_iter = (0..n_partitions).into_par_iter().map(|part_no| { + let mut hash_tbl: HashMap = + HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); + + let mut offset = 0; + for hashes in hashes { + for hashes in hashes.data_views() { + let len = hashes.len(); + let mut idx = 0; + hashes.iter().for_each(|h| { + // partition hashes by thread no. + // So only a part of the hashes go to this hashmap + if part_no == hash_to_partition(*h, n_partitions) { + let idx = idx + offset; + populate_multiple_key_hashmap(&mut hash_tbl, idx, *h, keys, || (), |_| ()) + } + idx += 1; + }); - offset += len as IdxSize; - } + offset += len as IdxSize; } - hash_tbl - }) - }) - .collect() + } + hash_tbl + }); + + POOL.install(|| par_iter.collect()) } #[cfg(feature = "semi_anti_join")] @@ -421,46 +410,43 @@ pub(crate) fn semi_anti_join_multiple_keys_impl<'a>( // next we probe the other relation // code duplication is because we want to only do the swap check once - POOL.install(move || { - probe_hashes - .into_par_iter() - .zip(offsets) - .flat_map(move |(probe_hashes, offset)| { - // local reference - let hash_tbls = &hash_tbls; - let mut results = - Vec::with_capacity(probe_hashes.len() / POOL.current_num_threads()); - let local_offset = offset; - - let mut idx_a = local_offset as IdxSize; - for probe_hashes in probe_hashes.data_views() { - for &h in probe_hashes { - // probe table that contains the hashed value - let current_probe_table = - unsafe { hash_tbls.get_unchecked(hash_to_partition(h, n_tables)) }; - - let entry = current_probe_table.raw_entry().from_hash(h, |idx_hash| { - let idx_b = idx_hash.idx; - // Safety: - // indices in a join operation are always in bounds. - unsafe { - compare_df_rows2(a, b, idx_a as usize, idx_b as usize, join_nulls) - } - }); - - match entry { - // left and right matches - Some((_, _)) => results.push((idx_a, true)), - // only left values, right = null - None => results.push((idx_a, false)), + probe_hashes + .into_par_iter() + .zip(offsets) + .flat_map(move |(probe_hashes, offset)| { + // local reference + let hash_tbls = &hash_tbls; + let mut results = Vec::with_capacity(probe_hashes.len() / POOL.current_num_threads()); + let local_offset = offset; + + let mut idx_a = local_offset as IdxSize; + for probe_hashes in probe_hashes.data_views() { + for &h in probe_hashes { + // probe table that contains the hashed value + let current_probe_table = + unsafe { hash_tbls.get_unchecked(hash_to_partition(h, n_tables)) }; + + let entry = current_probe_table.raw_entry().from_hash(h, |idx_hash| { + let idx_b = idx_hash.idx; + // SAFETY: + // indices in a join operation are always in bounds. + unsafe { + compare_df_rows2(a, b, idx_a as usize, idx_b as usize, join_nulls) } - idx_a += 1; + }); + + match entry { + // left and right matches + Some((_, _)) => results.push((idx_a, true)), + // only left values, right = null + None => results.push((idx_a, false)), } + idx_a += 1; } + } - results - }) - }) + results + }) } #[cfg(feature = "semi_anti_join")] @@ -469,10 +455,10 @@ pub fn _left_anti_multiple_keys( b: &mut DataFrame, join_nulls: bool, ) -> Vec { - semi_anti_join_multiple_keys_impl(a, b, join_nulls) + let par_iter = semi_anti_join_multiple_keys_impl(a, b, join_nulls) .filter(|tpls| !tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } #[cfg(feature = "semi_anti_join")] @@ -481,10 +467,10 @@ pub fn _left_semi_multiple_keys( b: &mut DataFrame, join_nulls: bool, ) -> Vec { - semi_anti_join_multiple_keys_impl(a, b, join_nulls) + let par_iter = semi_anti_join_multiple_keys_impl(a, b, join_nulls) .filter(|tpls| tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } /// Probe the build table and add tuples to the results (inner join) @@ -492,7 +478,7 @@ pub fn _left_semi_multiple_keys( #[allow(clippy::type_complexity)] fn probe_outer( probe_hashes: &[UInt64Chunked], - hash_tbls: &mut [HashMap), IdBuildHasher>], + hash_tbls: &mut [HashMap], results: &mut ( MutablePrimitiveArray, MutablePrimitiveArray, @@ -531,7 +517,7 @@ fn probe_outer( .raw_entry_mut() .from_hash(h, |idx_hash| { let idx_b = idx_hash.idx; - // Safety: + // SAFETY: // indices in a join operation are always in bounds. unsafe { compare_df_rows2(a, b, idx_a as usize, idx_b as usize, join_nulls) diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys.rs index f18a92978037..38b59c2d7454 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys.rs @@ -2,6 +2,8 @@ use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::idx_vec::IdxVec; use polars_utils::nulls::IsNull; use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; +use polars_utils::unitvec; use super::*; @@ -11,9 +13,13 @@ use super::*; // Use a small element per thread threshold for debugging/testing purposes. const MIN_ELEMS_PER_THREAD: usize = if cfg!(debug_assertions) { 1 } else { 128 }; -pub(crate) fn build_tables(keys: Vec, join_nulls: bool) -> Vec> +pub(crate) fn build_tables( + keys: Vec, + join_nulls: bool, +) -> Vec::TotalOrdItem, IdxVec>> where - T: Send + Hash + Eq + Sync + Copy + DirtyHash + IsNull, + T: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, I: IntoIterator + Send + Sync + Clone, { // FIXME: change interface to split the input here, instead of taking @@ -27,10 +33,11 @@ where // Don't bother parallelizing anything for small inputs. if num_keys_est < 2 * MIN_ELEMS_PER_THREAD { - let mut hm: PlHashMap = PlHashMap::new(); + let mut hm: PlHashMap = PlHashMap::new(); let mut offset = 0; for it in keys { for k in it { + let k = k.to_total_ord(); if !k.is_null() || join_nulls { hm.entry(k).or_default().push(offset); } @@ -48,6 +55,7 @@ where .map(|key_portion| { let mut partition_sizes = vec![0; n_partitions]; for key in key_portion.clone() { + let key = key.to_total_ord(); let p = hash_to_partition(key.dirty_hash(), n_partitions); unsafe { *partition_sizes.get_unchecked_mut(p) += 1; @@ -84,7 +92,7 @@ where } // Scatter values into partitions. - let mut scatter_keys: Vec = Vec::with_capacity(num_keys); + let mut scatter_keys: Vec = Vec::with_capacity(num_keys); let mut scatter_idxs: Vec = Vec::with_capacity(num_keys); let scatter_keys_ptr = unsafe { SyncPtr::new(scatter_keys.as_mut_ptr()) }; let scatter_idxs_ptr = unsafe { SyncPtr::new(scatter_idxs.as_mut_ptr()) }; @@ -95,6 +103,7 @@ where let mut partition_offsets = per_thread_partition_offsets[t * n_partitions..(t + 1) * n_partitions].to_vec(); for (i, key) in key_portion.into_iter().enumerate() { + let key = key.to_total_ord(); unsafe { let p = hash_to_partition(key.dirty_hash(), n_partitions); let off = partition_offsets.get_unchecked_mut(p); @@ -123,7 +132,8 @@ where let partition_range = partition_offsets[p]..partition_offsets[p + 1]; let full_size = partition_range.len(); let mut conservative_size = _HASHMAP_INIT_SIZE.max(full_size / 64); - let mut hm: PlHashMap = PlHashMap::with_capacity(conservative_size); + let mut hm: PlHashMap = + PlHashMap::with_capacity(conservative_size); unsafe { for i in partition_range { @@ -141,8 +151,7 @@ where o.get_mut().push(idx as IdxSize); }, Entry::Vacant(v) => { - let mut iv = IdxVec::new(); - iv.push(idx as IdxSize); + let iv = unitvec![idx as IdxSize]; v.insert(iv); }, }; @@ -160,8 +169,6 @@ where pub(super) fn probe_to_offsets(probe: &[I]) -> Vec where I: IntoIterator + Clone, - // ::IntoIter: TrustedLen, - T: Send + Hash + Eq + Sync + Copy + DirtyHash, { probe .iter() diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs index ac562d41bc09..9468ac483d3d 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs @@ -1,7 +1,8 @@ use arrow::array::PrimitiveArray; -use num_traits::NumCast; +use polars_core::with_match_physical_float_polars_type; use polars_utils::hashing::DirtyHash; use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use super::*; use crate::series::SeriesSealed; @@ -19,31 +20,29 @@ pub trait SeriesJoin: SeriesSealed + Sized { validate.validate_probe(&lhs, &rhs, false)?; use DataType::*; - match lhs.dtype() { - String => { - let lhs = lhs.cast(&Binary).unwrap(); - let rhs = rhs.cast(&Binary).unwrap(); - lhs.hash_join_left(&rhs, JoinValidation::ManyToMany, join_nulls) - }, - Binary => { - let lhs = lhs.binary().unwrap(); - let rhs = rhs.binary().unwrap(); - let (lhs, rhs, _, _) = prepare_binary(lhs, rhs, false); - let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); - let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); - hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls) - }, - _ => { - if s_self.bit_repr_is_large() { - let lhs = lhs.bit_repr_large(); - let rhs = rhs.bit_repr_large(); - num_group_join_left(&lhs, &rhs, validate, join_nulls) - } else { - let lhs = lhs.bit_repr_small(); - let rhs = rhs.bit_repr_small(); - num_group_join_left(&lhs, &rhs, validate, join_nulls) - } - }, + if matches!(lhs.dtype(), String | Binary) { + let lhs = lhs.cast(&Binary).unwrap(); + let rhs = rhs.cast(&Binary).unwrap(); + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, _, _) = prepare_binary(lhs, rhs, false); + let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); + let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); + hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls) + } else if lhs.dtype().is_float() { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + num_group_join_left(lhs, rhs, validate, join_nulls) + }) + } else if s_self.bit_repr_is_large() { + let lhs = lhs.bit_repr_large(); + let rhs = rhs.bit_repr_large(); + num_group_join_left(&lhs, &rhs, validate, join_nulls) + } else { + let lhs = lhs.bit_repr_small(); + let rhs = rhs.bit_repr_small(); + num_group_join_left(&lhs, &rhs, validate, join_nulls) } } @@ -53,35 +52,33 @@ pub trait SeriesJoin: SeriesSealed + Sized { let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); use DataType::*; - match lhs.dtype() { - String => { - let lhs = lhs.cast(&Binary).unwrap(); - let rhs = rhs.cast(&Binary).unwrap(); - lhs.hash_join_semi_anti(&rhs, anti) - }, - Binary => { - let lhs = lhs.binary().unwrap(); - let rhs = rhs.binary().unwrap(); - let (lhs, rhs, _, _) = prepare_binary(lhs, rhs, false); - let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); - let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); - if anti { - hash_join_tuples_left_anti(lhs, rhs) - } else { - hash_join_tuples_left_semi(lhs, rhs) - } - }, - _ => { - if s_self.bit_repr_is_large() { - let lhs = lhs.bit_repr_large(); - let rhs = rhs.bit_repr_large(); - num_group_join_anti_semi(&lhs, &rhs, anti) - } else { - let lhs = lhs.bit_repr_small(); - let rhs = rhs.bit_repr_small(); - num_group_join_anti_semi(&lhs, &rhs, anti) - } - }, + if matches!(lhs.dtype(), String | Binary) { + let lhs = lhs.cast(&Binary).unwrap(); + let rhs = rhs.cast(&Binary).unwrap(); + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, _, _) = prepare_binary(lhs, rhs, false); + let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); + let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); + if anti { + hash_join_tuples_left_anti(lhs, rhs) + } else { + hash_join_tuples_left_semi(lhs, rhs) + } + } else if lhs.dtype().is_float() { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + num_group_join_anti_semi(lhs, rhs, anti) + }) + } else if s_self.bit_repr_is_large() { + let lhs = lhs.bit_repr_large(); + let rhs = rhs.bit_repr_large(); + num_group_join_anti_semi(&lhs, &rhs, anti) + } else { + let lhs = lhs.bit_repr_small(); + let rhs = rhs.bit_repr_small(); + num_group_join_anti_semi(&lhs, &rhs, anti) } } @@ -97,34 +94,32 @@ pub trait SeriesJoin: SeriesSealed + Sized { validate.validate_probe(&lhs, &rhs, true)?; use DataType::*; - match lhs.dtype() { - String => { - let lhs = lhs.cast(&Binary).unwrap(); - let rhs = rhs.cast(&Binary).unwrap(); - lhs.hash_join_inner(&rhs, JoinValidation::ManyToMany, join_nulls) - }, - Binary => { - let lhs = lhs.binary().unwrap(); - let rhs = rhs.binary().unwrap(); - let (lhs, rhs, swapped, _) = prepare_binary(lhs, rhs, true); - let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); - let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); - Ok(( - hash_join_tuples_inner(lhs, rhs, swapped, validate, join_nulls)?, - !swapped, - )) - }, - _ => { - if s_self.bit_repr_is_large() { - let lhs = s_self.bit_repr_large(); - let rhs = other.bit_repr_large(); - group_join_inner::(&lhs, &rhs, validate, join_nulls) - } else { - let lhs = s_self.bit_repr_small(); - let rhs = other.bit_repr_small(); - group_join_inner::(&lhs, &rhs, validate, join_nulls) - } - }, + if matches!(lhs.dtype(), String | Binary) { + let lhs = lhs.cast(&Binary).unwrap(); + let rhs = rhs.cast(&Binary).unwrap(); + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, swapped, _) = prepare_binary(lhs, rhs, true); + let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); + let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); + Ok(( + hash_join_tuples_inner(lhs, rhs, swapped, validate, join_nulls)?, + !swapped, + )) + } else if lhs.dtype().is_float() { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + group_join_inner::<$T>(lhs, rhs, validate, join_nulls) + }) + } else if s_self.bit_repr_is_large() { + let lhs = s_self.bit_repr_large(); + let rhs = other.bit_repr_large(); + group_join_inner::(&lhs, &rhs, validate, join_nulls) + } else { + let lhs = s_self.bit_repr_small(); + let rhs = other.bit_repr_small(); + group_join_inner::(&lhs, &rhs, validate, join_nulls) } } @@ -139,31 +134,29 @@ pub trait SeriesJoin: SeriesSealed + Sized { validate.validate_probe(&lhs, &rhs, true)?; use DataType::*; - match lhs.dtype() { - String => { - let lhs = lhs.cast(&Binary).unwrap(); - let rhs = rhs.cast(&Binary).unwrap(); - lhs.hash_join_outer(&rhs, JoinValidation::ManyToMany, join_nulls) - }, - Binary => { - let lhs = lhs.binary().unwrap(); - let rhs = rhs.binary().unwrap(); - let (lhs, rhs, swapped, _) = prepare_binary(lhs, rhs, true); - let lhs = lhs.iter().collect::>(); - let rhs = rhs.iter().collect::>(); - hash_join_tuples_outer(lhs, rhs, swapped, validate, join_nulls) - }, - _ => { - if s_self.bit_repr_is_large() { - let lhs = s_self.bit_repr_large(); - let rhs = other.bit_repr_large(); - hash_join_outer(&lhs, &rhs, validate, join_nulls) - } else { - let lhs = s_self.bit_repr_small(); - let rhs = other.bit_repr_small(); - hash_join_outer(&lhs, &rhs, validate, join_nulls) - } - }, + if matches!(lhs.dtype(), String | Binary) { + let lhs = lhs.cast(&Binary).unwrap(); + let rhs = rhs.cast(&Binary).unwrap(); + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, swapped, _) = prepare_binary(lhs, rhs, true); + let lhs = lhs.iter().collect::>(); + let rhs = rhs.iter().collect::>(); + hash_join_tuples_outer(lhs, rhs, swapped, validate, join_nulls) + } else if lhs.dtype().is_float() { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + hash_join_outer(lhs, rhs, validate, join_nulls) + }) + } else if s_self.bit_repr_is_large() { + let lhs = s_self.bit_repr_large(); + let rhs = other.bit_repr_large(); + hash_join_outer(&lhs, &rhs, validate, join_nulls) + } else { + let lhs = s_self.bit_repr_small(); + let rhs = other.bit_repr_small(); + hash_join_outer(&lhs, &rhs, validate, join_nulls) } } } @@ -193,7 +186,10 @@ fn group_join_inner( where T: PolarsDataType, for<'a> &'a T::Array: IntoIterator>>, - for<'a> T::Physical<'a>: Hash + Eq + Send + DirtyHash + Copy + Send + Sync + IsNull, + for<'a> T::Physical<'a>: + Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd, + for<'a> as ToTotalOrd>::TotalOrdItem: + Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, { let n_threads = POOL.current_num_threads(); let (a, b, swapped) = det_hash_prone_order!(left, right); @@ -275,9 +271,11 @@ fn num_group_join_left( join_nulls: bool, ) -> PolarsResult where - T: PolarsIntegerType, - T::Native: Hash + Eq + Send + DirtyHash + IsNull, - Option: DirtyHash, + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, + T::Native: DirtyHash + Copy + ToTotalOrd, + as ToTotalOrd>::TotalOrdItem: Send + Sync + DirtyHash, { let n_threads = POOL.current_num_threads(); let splitted_a = split_ca(left, n_threads).unwrap(); @@ -332,8 +330,9 @@ fn hash_join_outer( join_nulls: bool, ) -> PolarsResult<(PrimitiveArray, PrimitiveArray)> where - T: PolarsIntegerType + Sync, - T::Native: Eq + Hash + NumCast, + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + IsNull, { let (a, b, swapped) = det_hash_prone_order!(ca_in, other); @@ -375,7 +374,7 @@ pub fn prepare_bytes<'a>( been_split .par_iter() .map(|ca| { - ca.into_iter() + ca.iter() .map(|opt_b| { let hash = hb.hash_one(opt_b); BytesHash::new(opt_b, hash) @@ -427,9 +426,10 @@ fn num_group_join_anti_semi( anti: bool, ) -> Vec where - T: PolarsIntegerType, - T::Native: Hash + Eq + Send + DirtyHash, - Option: DirtyHash, + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash, + as ToTotalOrd>::TotalOrdItem: Send + Sync + DirtyHash, { let n_threads = POOL.current_num_threads(); let splitted_a = split_ca(left, n_threads).unwrap(); diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs index bc5e5d4acdce..58bdd286a814 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs @@ -4,23 +4,25 @@ use polars_utils::idx_vec::IdxVec; use polars_utils::iter::EnumerateIdxTrait; use polars_utils::nulls::IsNull; use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use super::*; pub(super) fn probe_inner( probe: I, - hash_tbls: &[PlHashMap], + hash_tbls: &[PlHashMap<::TotalOrdItem, IdxVec>], results: &mut Vec<(IdxSize, IdxSize)>, local_offset: IdxSize, n_tables: usize, swap_fn: F, ) where - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + DirtyHash, I: IntoIterator, - // ::IntoIter: TrustedLen, F: Fn(IdxSize, IdxSize) -> (IdxSize, IdxSize), { probe.into_iter().enumerate_idx().for_each(|(idx_a, k)| { + let k = k.to_total_ord(); let idx_a = idx_a + local_offset; // probe table that contains the hashed value let current_probe_table = @@ -45,8 +47,8 @@ pub(super) fn hash_join_tuples_inner( ) -> PolarsResult<(Vec, Vec)> where I: IntoIterator + Send + Sync + Clone, - // ::IntoIter: TrustedLen, - T: Send + Hash + Eq + Sync + Copy + DirtyHash + IsNull, + T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, { // NOTE: see the left join for more elaborate comments // first we hash one relation diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs index 51956d41585d..7bdbb5dcaade 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs @@ -1,6 +1,7 @@ use polars_core::utils::flatten::flatten_par; use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use super::*; @@ -112,7 +113,8 @@ pub(super) fn hash_join_tuples_left( where I: IntoIterator, ::IntoIter: Send + Sync + Clone, - T: Send + Hash + Eq + Sync + Copy + DirtyHash + IsNull, + T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, { let probe = probe.into_iter().map(|i| i.into_iter()).collect::>(); let build = build.into_iter().map(|i| i.into_iter()).collect::>(); @@ -147,6 +149,7 @@ where let mut result_idx_right = Vec::with_capacity(probe.size_hint().1.unwrap()); probe.enumerate().for_each(|(idx_a, k)| { + let k = k.to_total_ord(); let idx_a = (idx_a + offset) as IdxSize; // probe table that contains the hashed value let current_probe_table = unsafe { diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs index 61c48eefa934..f2e0f21ad8f7 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs @@ -1,7 +1,10 @@ use arrow::array::{MutablePrimitiveArray, PrimitiveArray}; use arrow::legacy::utils::CustomIterTools; use polars_utils::hashing::hash_to_partition; +use polars_utils::idx_vec::IdxVec; use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; +use polars_utils::unitvec; use super::*; @@ -12,7 +15,8 @@ pub(crate) fn create_hash_and_keys_threaded_vectorized( where I: IntoIterator + Send, I::IntoIter: TrustedLen, - T: Send + Hash + Eq, + T: TotalHash + TotalEq + Send + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let build_hasher = build_hasher.unwrap_or_default(); let hashes = POOL.install(|| { @@ -21,7 +25,7 @@ where .map(|iter| { // create hashes and keys iter.into_iter() - .map(|val| (build_hasher.hash_one(&val), val)) + .map(|val| (build_hasher.hash_one(&val.to_total_ord()), val)) .collect_trusted::>() }) .collect() @@ -31,10 +35,11 @@ where pub(crate) fn prepare_hashed_relation_threaded( iters: Vec, -) -> Vec)>> +) -> Vec::TotalOrdItem, (bool, IdxVec)>> where I: Iterator + Send + TrustedLen, - T: Send + Hash + Eq + Sync + Copy, + T: Send + Sync + TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq, { let n_partitions = _set_partition_size(); let (hashes_and_keys, build_hasher) = create_hash_and_keys_threaded_vectorized(iters, None); @@ -48,7 +53,7 @@ where .map(|partition_no| { let build_hasher = build_hasher.clone(); let hashes_and_keys = &hashes_and_keys; - let mut hash_tbl: PlHashMap)> = + let mut hash_tbl: PlHashMap = PlHashMap::with_hasher(build_hasher); let mut offset = 0; @@ -58,6 +63,7 @@ where .iter() .enumerate() .for_each(|(idx, (h, k))| { + let k = k.to_total_ord(); let idx = idx as IdxSize; // partition hashes by thread no. // So only a part of the hashes go to this hashmap @@ -66,11 +72,11 @@ where let entry = hash_tbl .raw_entry_mut() // uses the key to check equality to find and entry - .from_key_hashed_nocheck(*h, k); + .from_key_hashed_nocheck(*h, &k); match entry { RawEntryMut::Vacant(entry) => { - entry.insert_hashed_nocheck(*h, *k, (false, vec![idx])); + entry.insert_hashed_nocheck(*h, k, (false, unitvec![idx])); }, RawEntryMut::Occupied(mut entry) => { let (_k, v) = entry.get_key_value_mut(); @@ -92,7 +98,7 @@ where #[allow(clippy::too_many_arguments)] fn probe_outer( probe_hashes: &[Vec<(u64, T)>], - hash_tbls: &mut [PlHashMap)>], + hash_tbls: &mut [PlHashMap<::TotalOrdItem, (bool, IdxVec)>], results: &mut ( MutablePrimitiveArray, MutablePrimitiveArray, @@ -106,7 +112,8 @@ fn probe_outer( swap_fn_drain: H, join_nulls: bool, ) where - T: Send + Hash + Eq + Sync + Copy + IsNull, + T: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + IsNull, // idx_a, idx_b -> ... F: Fn(IdxSize, IdxSize) -> (Option, Option), // idx_a -> ... @@ -118,6 +125,7 @@ fn probe_outer( let mut idx_a = 0; for probe_hashes in probe_hashes { for (h, key) in probe_hashes { + let key = key.to_total_ord(); let h = *h; // probe table that contains the hashed value let current_probe_table = @@ -125,7 +133,7 @@ fn probe_outer( let entry = current_probe_table .raw_entry_mut() - .from_key_hashed_nocheck(h, key); + .from_key_hashed_nocheck(h, &key); match entry { // match and remove @@ -180,7 +188,8 @@ where J: IntoIterator, ::IntoIter: TrustedLen + Send, ::IntoIter: TrustedLen + Send, - T: Hash + Eq + Copy + Sync + Send + IsNull, + T: Send + Sync + TotalHash + TotalEq + IsNull + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + IsNull, { let probe = probe.into_iter().map(|i| i.into_iter()).collect::>(); let build = build.into_iter().map(|i| i.into_iter()).collect::>(); diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs index 93268036c43d..57196e86632d 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs @@ -1,11 +1,15 @@ use polars_utils::hashing::{hash_to_partition, DirtyHash}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use super::*; /// Only keeps track of membership in right table -pub(super) fn create_probe_table_semi_anti(keys: Vec) -> Vec> +pub(super) fn create_probe_table_semi_anti( + keys: Vec, +) -> Vec::TotalOrdItem>> where - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash, I: IntoIterator + Copy + Send + Sync, { let n_partitions = _set_partition_size(); @@ -13,29 +17,31 @@ where // We will create a hashtable in every thread. // We use the hash to partition the keys to the matching hashtable. // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - POOL.install(|| { - (0..n_partitions).into_par_iter().map(|partition_no| { - let mut hash_tbl: PlHashSet = PlHashSet::with_capacity(_HASHMAP_INIT_SIZE); - for keys in &keys { - keys.into_iter().for_each(|k| { - if partition_no == hash_to_partition(k.dirty_hash(), n_partitions) { - hash_tbl.insert(k); - } - }); - } - hash_tbl - }) - }) - .collect() + let par_iter = (0..n_partitions).into_par_iter().map(|partition_no| { + let mut hash_tbl: PlHashSet = PlHashSet::with_capacity(_HASHMAP_INIT_SIZE); + for keys in &keys { + keys.into_iter().for_each(|k| { + let k = k.to_total_ord(); + if partition_no == hash_to_partition(k.dirty_hash(), n_partitions) { + hash_tbl.insert(k); + } + }); + } + hash_tbl + }); + POOL.install(|| par_iter.collect()) } -pub(super) fn semi_anti_impl( +/// Construct a ParallelIterator, but doesn't iterate it. This means the caller +/// context (or wherever it gets iterated) should be in POOL.install. +fn semi_anti_impl( probe: Vec, build: Vec, ) -> impl ParallelIterator where I: IntoIterator + Copy + Send + Sync, - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash, { // first we hash one relation let hash_sets = create_probe_table_semi_anti(build); @@ -46,60 +52,61 @@ where let n_tables = hash_sets.len(); // next we probe the other relation - POOL.install(move || { - probe - .into_par_iter() - .zip(offsets) - // probes_hashes: Vec processed by this thread - // offset: offset index - .flat_map(move |(probe, offset)| { - // local reference - let hash_sets = &hash_sets; - let probe_iter = probe.into_iter(); + // This is not wrapped in POOL.install because it is not being iterated here + probe + .into_par_iter() + .zip(offsets) + // probes_hashes: Vec processed by this thread + // offset: offset index + .flat_map(move |(probe, offset)| { + // local reference + let hash_sets = &hash_sets; + let probe_iter = probe.into_iter(); - // assume the result tuples equal length of the no. of hashes processed by this thread. - let mut results = Vec::with_capacity(probe_iter.size_hint().1.unwrap()); + // assume the result tuples equal length of the no. of hashes processed by this thread. + let mut results = Vec::with_capacity(probe_iter.size_hint().1.unwrap()); - probe_iter.enumerate().for_each(|(idx_a, k)| { - let idx_a = (idx_a + offset) as IdxSize; - // probe table that contains the hashed value - let current_probe_table = unsafe { - hash_sets.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) - }; + probe_iter.enumerate().for_each(|(idx_a, k)| { + let k = k.to_total_ord(); + let idx_a = (idx_a + offset) as IdxSize; + // probe table that contains the hashed value + let current_probe_table = + unsafe { hash_sets.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) }; - // we already hashed, so we don't have to hash again. - let value = current_probe_table.get(&k); + // we already hashed, so we don't have to hash again. + let value = current_probe_table.get(&k); - match value { - // left and right matches - Some(_) => results.push((idx_a, true)), - // only left values, right = null - None => results.push((idx_a, false)), - } - }); - results - }) - }) + match value { + // left and right matches + Some(_) => results.push((idx_a, true)), + // only left values, right = null + None => results.push((idx_a, false)), + } + }); + results + }) } pub(super) fn hash_join_tuples_left_anti(probe: Vec, build: Vec) -> Vec where I: IntoIterator + Copy + Send + Sync, - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash, { - semi_anti_impl(probe, build) + let par_iter = semi_anti_impl(probe, build) .filter(|tpls| !tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } pub(super) fn hash_join_tuples_left_semi(probe: Vec, build: Vec) -> Vec where I: IntoIterator + Copy + Send + Sync, - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash, { - semi_anti_impl(probe, build) + let par_iter = semi_anti_impl(probe, build) .filter(|tpls| tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } diff --git a/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs b/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs index d9b849ce1e59..6d97ee4735f4 100644 --- a/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs +++ b/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs @@ -23,13 +23,12 @@ where let slice_left = s_left.cont_slice().unwrap(); let slice_right = s_right.cont_slice().unwrap(); - let indexes = offsets - .into_par_iter() - .map(|(offset, len)| { - let slice_left = &slice_left[offset..offset + len]; - sorted_join::left::join(slice_left, slice_right, offset as IdxSize) - }) - .collect::>(); + let indexes = offsets.into_par_iter().map(|(offset, len)| { + let slice_left = &slice_left[offset..offset + len]; + sorted_join::left::join(slice_left, slice_right, offset as IdxSize) + }); + let indexes = POOL.install(|| indexes.collect::>()); + let lefts = indexes.iter().map(|t| &t.0).collect::>(); let rights = indexes.iter().map(|t| &t.1).collect::>(); @@ -96,13 +95,12 @@ where let slice_left = s_left.cont_slice().unwrap(); let slice_right = s_right.cont_slice().unwrap(); - let indexes = offsets - .into_par_iter() - .map(|(offset, len)| { - let slice_left = &slice_left[offset..offset + len]; - sorted_join::inner::join(slice_left, slice_right, offset as IdxSize) - }) - .collect::>(); + let indexes = offsets.into_par_iter().map(|(offset, len)| { + let slice_left = &slice_left[offset..offset + len]; + sorted_join::inner::join(slice_left, slice_right, offset as IdxSize) + }); + let indexes = POOL.install(|| indexes.collect::>()); + let lefts = indexes.iter().map(|t| &t.0).collect::>(); let rights = indexes.iter().map(|t| &t.1).collect::>(); diff --git a/crates/polars-ops/src/frame/join/merge_sorted.rs b/crates/polars-ops/src/frame/join/merge_sorted.rs index 9a4267e76135..fc687aaa623f 100644 --- a/crates/polars-ops/src/frame/join/merge_sorted.rs +++ b/crates/polars-ops/src/frame/join/merge_sorted.rs @@ -43,7 +43,7 @@ pub fn _merge_sorted_dfs( }) .collect(); - Ok(DataFrame::new_no_checks(new_columns)) + Ok(unsafe { DataFrame::new_no_checks(new_columns) }) } fn merge_series(lhs: &Series, rhs: &Series, merge_indicator: &[bool]) -> Series { @@ -120,7 +120,7 @@ where } }); - // Safety: length is correct + // SAFETY: length is correct unsafe { iter.trust_my_length(total_len).collect_trusted() } } diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 90b6e9ef9370..5c9eeb577fc7 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -9,7 +9,6 @@ mod hash_join; #[cfg(feature = "merge_sorted")] mod merge_sorted; -#[cfg(feature = "chunked_ids")] use std::borrow::Cow; use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; @@ -34,7 +33,9 @@ pub use merge_sorted::_merge_sorted_dfs; use polars_core::hashing::{_df_rows_to_hashes_threaded_vertical, _HASHMAP_INIT_SIZE}; use polars_core::prelude::*; pub(super) use polars_core::series::IsSorted; -use polars_core::utils::{_to_physical_and_bit_repr, slice_offsets, slice_slice}; +#[allow(unused_imports)] +use polars_core::utils::slice_slice; +use polars_core::utils::{_to_physical_and_bit_repr, slice_offsets}; use polars_core::POOL; use polars_utils::hashing::BytesHash; use rayon::prelude::*; @@ -273,8 +274,8 @@ pub trait DataFrameJoinOps: IntoDf { // Multiple keys. match args.how { JoinType::Inner => { - let left = DataFrame::new_no_checks(selected_left_physical); - let right = DataFrame::new_no_checks(selected_right_physical); + let left = unsafe { DataFrame::new_no_checks(selected_left_physical) }; + let right = unsafe { DataFrame::new_no_checks(selected_right_physical) }; let (mut left, mut right, swap) = det_hash_prone_order!(left, right); let (join_idx_left, join_idx_right) = _inner_join_multiple_keys(&mut left, &mut right, swap, args.join_nulls); @@ -287,7 +288,7 @@ pub trait DataFrameJoinOps: IntoDf { } let (df_left, df_right) = POOL.join( - // safety: join indices are known to be in bounds + // SAFETY: join indices are known to be in bounds || unsafe { left_df._create_left_df_from_slice(join_idx_left, false, !swap) }, || unsafe { // remove join columns @@ -298,8 +299,8 @@ pub trait DataFrameJoinOps: IntoDf { _finish_join(df_left, df_right, args.suffix.as_deref()) }, JoinType::Left => { - let mut left = DataFrame::new_no_checks(selected_left_physical); - let mut right = DataFrame::new_no_checks(selected_right_physical); + let mut left = unsafe { DataFrame::new_no_checks(selected_left_physical) }; + let mut right = unsafe { DataFrame::new_no_checks(selected_right_physical) }; if let Some((offset, len)) = args.slice { left = left.slice(offset, len); @@ -309,8 +310,8 @@ pub trait DataFrameJoinOps: IntoDf { left_df._finish_left_join(ids, &remove_selected(other, &selected_right), args) }, JoinType::Outer { .. } => { - let df_left = DataFrame::new_no_checks(selected_left_physical); - let df_right = DataFrame::new_no_checks(selected_right_physical); + let df_left = unsafe { DataFrame::new_no_checks(selected_left_physical) }; + let df_right = unsafe { DataFrame::new_no_checks(selected_right_physical) }; let (mut left, mut right, swap) = det_hash_prone_order!(df_left, df_right); let (mut join_idx_l, mut join_idx_r) = @@ -354,15 +355,15 @@ pub trait DataFrameJoinOps: IntoDf { ), #[cfg(feature = "semi_anti_join")] JoinType::Anti | JoinType::Semi => { - let mut left = DataFrame::new_no_checks(selected_left_physical); - let mut right = DataFrame::new_no_checks(selected_right_physical); + let mut left = unsafe { DataFrame::new_no_checks(selected_left_physical) }; + let mut right = unsafe { DataFrame::new_no_checks(selected_right_physical) }; let idx = if matches!(args.how, JoinType::Anti) { _left_anti_multiple_keys(&mut left, &mut right, args.join_nulls) } else { _left_semi_multiple_keys(&mut left, &mut right, args.join_nulls) }; - // Safety: + // SAFETY: // indices are in bounds Ok(unsafe { left_df._finish_anti_semi_join(&idx, args.slice) }) }, @@ -497,7 +498,7 @@ trait DataFrameJoinOpsPrivate: IntoDf { } let (df_left, df_right) = POOL.join( - // safety: join indices are known to be in bounds + // SAFETY: join indices are known to be in bounds || unsafe { left_df._create_left_df_from_slice(join_tuples_left, false, sorted) }, || unsafe { other diff --git a/crates/polars-ops/src/frame/pivot/mod.rs b/crates/polars-ops/src/frame/pivot/mod.rs index 89857eea6d3d..cec9ddd01cdb 100644 --- a/crates/polars-ops/src/frame/pivot/mod.rs +++ b/crates/polars-ops/src/frame/pivot/mod.rs @@ -30,7 +30,7 @@ fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series { (dt @ DataType::Categorical(Some(rev_map), ordering), _) | (dt @ DataType::Enum(Some(rev_map), ordering), _) => { let cats = s.u32().unwrap().clone(); - // safety: + // SAFETY: // the rev-map comes from these categoricals unsafe { CategoricalChunked::from_cats_and_rev_map_unchecked( @@ -82,27 +82,23 @@ fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series { /// # Note /// Polars'/arrow memory is not ideal for transposing operations like pivots. /// If you have a relatively large table, consider using a group_by over a pivot. -pub fn pivot( +pub fn pivot( pivot_df: &DataFrame, - values: I0, - index: I1, - columns: I2, + index: I0, + columns: I1, + values: Option, sort_columns: bool, agg_fn: Option, separator: Option<&str>, ) -> PolarsResult where I0: IntoIterator, - S0: AsRef, I1: IntoIterator, - S1: AsRef, I2: IntoIterator, + S0: AsRef, + S1: AsRef, S2: AsRef, { - let values = values - .into_iter() - .map(|s| s.as_ref().to_string()) - .collect::>(); let index = index .into_iter() .map(|s| s.as_ref().to_string()) @@ -111,11 +107,12 @@ where .into_iter() .map(|s| s.as_ref().to_string()) .collect::>(); + let values = get_values_columns(pivot_df, &index, &columns, values); pivot_impl( pivot_df, - &values, &index, &columns, + &values, agg_fn, sort_columns, false, @@ -128,27 +125,23 @@ where /// # Note /// Polars'/arrow memory is not ideal for transposing operations like pivots. /// If you have a relatively large table, consider using a group_by over a pivot. -pub fn pivot_stable( +pub fn pivot_stable( pivot_df: &DataFrame, - values: I0, - index: I1, - columns: I2, + index: I0, + columns: I1, + values: Option, sort_columns: bool, agg_fn: Option, separator: Option<&str>, ) -> PolarsResult where I0: IntoIterator, - S0: AsRef, I1: IntoIterator, - S1: AsRef, I2: IntoIterator, + S0: AsRef, + S1: AsRef, S2: AsRef, { - let values = values - .into_iter() - .map(|s| s.as_ref().to_string()) - .collect::>(); let index = index .into_iter() .map(|s| s.as_ref().to_string()) @@ -157,12 +150,12 @@ where .into_iter() .map(|s| s.as_ref().to_string()) .collect::>(); - + let values = get_values_columns(pivot_df, &index, &columns, values); pivot_impl( pivot_df, - &values, &index, &columns, + &values, agg_fn, sort_columns, true, @@ -170,16 +163,41 @@ where ) } +/// Determine `values` columns, which is optional in `pivot` calls. +/// +/// If not specified (i.e. is `None`), use all remaining columns in the +/// `DataFrame` after `index` and `columns` have been excluded. +fn get_values_columns( + df: &DataFrame, + index: &[String], + columns: &[String], + values: Option, +) -> Vec +where + I: IntoIterator, + S: AsRef, +{ + match values { + Some(v) => v.into_iter().map(|s| s.as_ref().to_string()).collect(), + None => df + .get_column_names() + .into_iter() + .map(|c| c.to_string()) + .filter(|c| !(index.contains(c) | columns.contains(c))) + .collect(), + } +} + #[allow(clippy::too_many_arguments)] fn pivot_impl( pivot_df: &DataFrame, - // these columns will be aggregated in the nested group_by - values: &[String], // keys of the first group_by operation index: &[String], // these columns will be used for a nested group_by // the rows of this nested group_by will be pivoted as header column values columns: &[String], + // these columns will be aggregated in the nested group_by + values: &[String], // aggregation function agg_fn: Option, sort_columns: bool, @@ -187,120 +205,161 @@ fn pivot_impl( // used as separator/delimiter in generated column names. separator: Option<&str>, ) -> PolarsResult { - let sep = separator.unwrap_or("_"); polars_ensure!(!index.is_empty(), ComputeError: "index cannot be zero length"); + polars_ensure!(!columns.is_empty(), ComputeError: "columns cannot be zero length"); + if !stable { + println!("unstable pivot not yet supported, using stable pivot"); + }; + if columns.len() > 1 { + let schema = Arc::new(pivot_df.schema()); + let binding = pivot_df.select_with_schema(columns, &schema)?; + let fields = binding.get_columns(); + let column = format!("{{\"{}\"}}", columns.join("\",\"")); + if schema.contains(column.as_str()) { + polars_bail!(ComputeError: "cannot use column name {column} that \ + already exists in the DataFrame. Please rename it prior to calling `pivot`.") + } + let columns_struct = StructChunked::new(&column, fields).unwrap().into_series(); + let mut binding = pivot_df.clone(); + let pivot_df = unsafe { binding.with_column_unchecked(columns_struct) }; + pivot_impl_single_column( + pivot_df, + index, + &column, + values, + agg_fn, + sort_columns, + separator, + ) + } else { + pivot_impl_single_column( + pivot_df, + index, + unsafe { columns.get_unchecked(0) }, + values, + agg_fn, + sort_columns, + separator, + ) + } +} +fn pivot_impl_single_column( + pivot_df: &DataFrame, + index: &[String], + column: &str, + values: &[String], + agg_fn: Option, + sort_columns: bool, + separator: Option<&str>, +) -> PolarsResult { + let sep = separator.unwrap_or("_"); let mut final_cols = vec![]; - let mut count = 0; let out: PolarsResult<()> = POOL.install(|| { - for column_column_name in columns { - let mut group_by = index.to_vec(); - group_by.push(column_column_name.clone()); - - let groups = pivot_df.group_by_stable(group_by)?.take_groups(); + let mut group_by = index.to_vec(); + group_by.push(column.to_string()); - // these are the row locations - if !stable { - println!("unstable pivot not yet supported, using stable pivot"); - }; + let groups = pivot_df.group_by_stable(group_by)?.take_groups(); - let (col, row) = POOL.join( - || positioning::compute_col_idx(pivot_df, column_column_name, &groups), - || positioning::compute_row_idx(pivot_df, index, &groups, count), - ); - let (col_locations, column_agg) = col?; - let (row_locations, n_rows, mut row_index) = row?; + let (col, row) = POOL.join( + || positioning::compute_col_idx(pivot_df, column, &groups), + || positioning::compute_row_idx(pivot_df, index, &groups, count), + ); + let (col_locations, column_agg) = col?; + let (row_locations, n_rows, mut row_index) = row?; - for value_col_name in values { - let value_col = pivot_df.column(value_col_name)?; + for value_col_name in values { + let value_col = pivot_df.column(value_col_name)?; - use PivotAgg::*; - let value_agg = unsafe { - match &agg_fn { - None => match value_col.len() > groups.len() { - true => polars_bail!(ComputeError: "found multiple elements in the same group, please specify an aggregation function"), - false => value_col.agg_first(&groups), - } - Some(agg_fn) => match agg_fn { - Sum => value_col.agg_sum(&groups), - Min => value_col.agg_min(&groups), - Max => value_col.agg_max(&groups), - Last => value_col.agg_last(&groups), - First => value_col.agg_first(&groups), - Mean => value_col.agg_mean(&groups), - Median => value_col.agg_median(&groups), - Count => groups.group_count().into_series(), - Expr(ref expr) => { - let name = expr.root_name()?; - let mut value_col = value_col.clone(); - value_col.rename(name); - let tmp_df = DataFrame::new_no_checks(vec![value_col]); - let mut aggregated = expr.evaluate(&tmp_df, &groups)?; - aggregated.rename(value_col_name); - aggregated - } - }, + use PivotAgg::*; + let value_agg = unsafe { + match &agg_fn { + None => match value_col.len() > groups.len() { + true => polars_bail!(ComputeError: "found multiple elements in the same group, please specify an aggregation function"), + false => value_col.agg_first(&groups), } - }; - - let headers = column_agg.unique_stable()?.cast(&DataType::String)?; - let mut headers = headers.str().unwrap().clone(); - if values.len() > 1 { - headers = headers.apply_values(|v| Cow::from(format!("{value_col_name}{sep}{column_column_name}{sep}{v}"))) + Some(agg_fn) => match agg_fn { + Sum => value_col.agg_sum(&groups), + Min => value_col.agg_min(&groups), + Max => value_col.agg_max(&groups), + Last => value_col.agg_last(&groups), + First => value_col.agg_first(&groups), + Mean => value_col.agg_mean(&groups), + Median => value_col.agg_median(&groups), + Count => groups.group_count().into_series(), + Expr(ref expr) => { + let name = expr.root_name()?; + let mut value_col = value_col.clone(); + value_col.rename(name); + let tmp_df = value_col.into_frame(); + let mut aggregated = expr.evaluate(&tmp_df, &groups)?; + aggregated.rename(value_col_name); + aggregated + } + }, } + }; - let n_cols = headers.len(); - let value_agg_phys = value_agg.to_physical_repr(); - let logical_type = value_agg.dtype(); + let headers = column_agg.unique_stable()?.cast(&DataType::String)?; + let mut headers = headers.str().unwrap().clone(); + if values.len() > 1 { + // TODO! MILESTONE 1.0: change to `format!("{value_col_name}{sep}{v}")` + headers = headers.apply_values(|v| Cow::from(format!("{value_col_name}{sep}{column}{sep}{v}"))) + } - debug_assert_eq!(row_locations.len(), col_locations.len()); - debug_assert_eq!(value_agg_phys.len(), row_locations.len()); + let n_cols = headers.len(); + let value_agg_phys = value_agg.to_physical_repr(); + let logical_type = value_agg.dtype(); - let mut cols = if value_agg_phys.dtype().is_numeric() { - macro_rules! dispatch { - ($ca:expr) => {{ - positioning::position_aggregates_numeric( - n_rows, - n_cols, - &row_locations, - &col_locations, - $ca, - logical_type, - &headers, - ) - }}; - } - downcast_as_macro_arg_physical!(value_agg_phys, dispatch) - } else { - positioning::position_aggregates( - n_rows, - n_cols, - &row_locations, - &col_locations, - &value_agg_phys, - logical_type, - &headers, - ) - }; + debug_assert_eq!(row_locations.len(), col_locations.len()); + debug_assert_eq!(value_agg_phys.len(), row_locations.len()); - if sort_columns { - cols.sort_unstable_by(|a, b| a.name().partial_cmp(b.name()).unwrap()); + let mut cols = if value_agg_phys.dtype().is_numeric() { + macro_rules! dispatch { + ($ca:expr) => {{ + positioning::position_aggregates_numeric( + n_rows, + n_cols, + &row_locations, + &col_locations, + $ca, + logical_type, + &headers, + ) + }}; } + downcast_as_macro_arg_physical!(value_agg_phys, dispatch) + } else { + positioning::position_aggregates( + n_rows, + n_cols, + &row_locations, + &col_locations, + &value_agg_phys, + logical_type, + &headers, + ) + }; - let cols = if count == 0 { - let mut final_cols = row_index.take().unwrap(); - final_cols.extend(cols); - final_cols - } else { - cols - }; - count += 1; - final_cols.extend_from_slice(&cols); + if sort_columns { + cols.sort_unstable_by(|a, b| a.name().partial_cmp(b.name()).unwrap()); } + + let cols = if count == 0 { + let mut final_cols = row_index.take().unwrap(); + final_cols.extend(cols); + final_cols + } else { + cols + }; + count += 1; + final_cols.extend_from_slice(&cols); } Ok(()) }); out?; - Ok(DataFrame::new_no_checks(final_cols)) + + // SAFETY: length has already been checked. + unsafe { DataFrame::new_no_length_checks(final_cols) } } diff --git a/crates/polars-ops/src/frame/pivot/positioning.rs b/crates/polars-ops/src/frame/pivot/positioning.rs index 43450a7e91b5..5ad0b32f101d 100644 --- a/crates/polars-ops/src/frame/pivot/positioning.rs +++ b/crates/polars-ops/src/frame/pivot/positioning.rs @@ -1,7 +1,9 @@ use std::hash::Hash; +use arrow::legacy::trusted_len::TrustedLenPush; use polars_core::prelude::*; use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use super::*; @@ -21,7 +23,7 @@ pub(super) fn position_aggregates( let split = _split_offsets(row_locations.len(), n_threads); // ensure the slice series are not dropped - // so the anyvalues are referencing correct data, if they reference arrays (struct) + // so the AnyValues are referencing correct data, if they reference arrays (struct) let n_splits = split.len(); let mut arrays: Vec = Vec::with_capacity(n_splits); @@ -43,7 +45,7 @@ pub(super) fn position_aggregates( .zip(col_locations) .zip(value_agg_phys.phys_iter()) { - // Safety: + // SAFETY: // in bounds unsafe { let idx = *row_idx as usize + *col_idx as usize * n_rows; @@ -115,7 +117,7 @@ where let split = _split_offsets(row_locations.len(), n_threads); let n_splits = split.len(); // ensure the arrays are not dropped - // so the anyvalues are referencing correct data, if they reference arrays (struct) + // so the AnyValues are referencing correct data, if they reference arrays (struct) let mut arrays: Vec> = Vec::with_capacity(n_splits); // every thread will only write to their partition @@ -138,7 +140,7 @@ where .zip(col_locations) .zip(value_agg_phys.into_iter()) { - // Safety: + // SAFETY: // in bounds unsafe { let idx = *row_idx as usize + *col_idx as usize * n_rows; @@ -174,21 +176,50 @@ where fn compute_col_idx_numeric(column_agg_physical: &ChunkedArray) -> Vec where T: PolarsNumericType, - T::Native: Hash + Eq, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let mut col_to_idx = PlHashMap::with_capacity(HASHMAP_INIT_SIZE); let mut idx = 0 as IdxSize; - column_agg_physical - .into_iter() - .map(|v| { - let idx = *col_to_idx.entry(v).or_insert_with(|| { + let mut out = Vec::with_capacity(column_agg_physical.len()); + + for opt_v in column_agg_physical.iter() { + let opt_v = opt_v.to_total_ord(); + let idx = *col_to_idx.entry(opt_v).or_insert_with(|| { + let old_idx = idx; + idx += 1; + old_idx + }); + // SAFETY: + // we pre-allocated + unsafe { out.push_unchecked(idx) }; + } + out +} + +fn compute_col_idx_gen<'a, T>(column_agg_physical: &'a ChunkedArray) -> Vec +where + T: PolarsDataType, + &'a T::Array: IntoIterator>>, + T::Physical<'a>: Hash + Eq, +{ + let mut col_to_idx = PlHashMap::with_capacity(HASHMAP_INIT_SIZE); + let mut idx = 0 as IdxSize; + let mut out = Vec::with_capacity(column_agg_physical.len()); + + for arr in column_agg_physical.downcast_iter() { + for opt_v in arr.into_iter() { + let idx = *col_to_idx.entry(opt_v).or_insert_with(|| { let old_idx = idx; idx += 1; old_idx }); - idx - }) - .collect() + // SAFETY: + // we pre-allocated + unsafe { out.push_unchecked(idx) }; + } + } + out } pub(super) fn compute_col_idx( @@ -202,14 +233,40 @@ pub(super) fn compute_col_idx( use DataType::*; let col_locations = match column_agg_physical.dtype() { - Int32 | UInt32 | Float32 => { + Int32 | UInt32 => { let ca = column_agg_physical.bit_repr_small(); compute_col_idx_numeric(&ca) }, - Int64 | UInt64 | Float64 => { + Int64 | UInt64 => { let ca = column_agg_physical.bit_repr_large(); compute_col_idx_numeric(&ca) }, + Float64 => { + let ca: &ChunkedArray = column_agg_physical.as_ref().as_ref().as_ref(); + compute_col_idx_numeric(ca) + }, + Float32 => { + let ca: &ChunkedArray = column_agg_physical.as_ref().as_ref().as_ref(); + compute_col_idx_numeric(ca) + }, + Struct(_) => { + let ca = column_agg_physical.struct_().unwrap(); + let ca = ca.rows_encode()?; + compute_col_idx_gen(&ca) + }, + String => { + let ca = column_agg_physical.str().unwrap(); + let ca = ca.as_binary(); + compute_col_idx_gen(&ca) + }, + Binary => { + let ca = column_agg_physical.binary().unwrap(); + compute_col_idx_gen(ca) + }, + Boolean => { + let ca = column_agg_physical.bool().unwrap(); + compute_col_idx_gen(ca) + }, _ => { let mut col_to_idx = PlHashMap::with_capacity(HASHMAP_INIT_SIZE); let mut idx = 0 as IdxSize; @@ -230,37 +287,43 @@ pub(super) fn compute_col_idx( Ok((col_locations, column_agg)) } -fn compute_row_idx_numeric( +fn compute_row_index<'a, T>( index: &[String], - index_agg_physical: &ChunkedArray, + index_agg_physical: &'a ChunkedArray, count: usize, logical_type: &DataType, ) -> (Vec, usize, Option>) where - T: PolarsNumericType, - T::Native: Hash + Eq, + T: PolarsDataType, + T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + > as ToTotalOrd>::TotalOrdItem: Hash + Eq, + ChunkedArray: FromIterator>>, ChunkedArray: IntoSeries, { let mut row_to_idx = PlIndexMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default()); let mut idx = 0 as IdxSize; - let row_locations = index_agg_physical - .into_iter() - .map(|v| { - let idx = *row_to_idx.entry(v).or_insert_with(|| { - let old_idx = idx; - idx += 1; - old_idx - }); - idx - }) - .collect::>(); + let mut row_locations = Vec::with_capacity(index_agg_physical.len()); + for opt_v in index_agg_physical.iter() { + let opt_v = opt_v.to_total_ord(); + let idx = *row_to_idx.entry(opt_v).or_insert_with(|| { + let old_idx = idx; + idx += 1; + old_idx + }); + + // SAFETY: + // we pre-allocated + unsafe { + row_locations.push_unchecked(idx); + } + } let row_index = match count { 0 => { let mut s = row_to_idx .into_iter() - .map(|(k, _)| k) + .map(|(k, _)| Option::>::peel_total_ord(k)) .collect::>() .into_series(); s.rename(&index[0]); @@ -273,6 +336,51 @@ where (row_locations, idx as usize, row_index) } +fn compute_row_index_struct( + index: &[String], + index_agg: &Series, + index_agg_physical: &BinaryOffsetChunked, + count: usize, +) -> (Vec, usize, Option>) { + let mut row_to_idx = + PlIndexMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default()); + let mut idx = 0 as IdxSize; + + let mut row_locations = Vec::with_capacity(index_agg_physical.len()); + let mut unique_indices = Vec::with_capacity(index_agg_physical.len()); + let mut row_number: IdxSize = 0; + for arr in index_agg_physical.downcast_iter() { + for opt_v in arr.iter() { + let idx = *row_to_idx.entry(opt_v).or_insert_with(|| { + // SAFETY: we pre-allocated + unsafe { unique_indices.push_unchecked(row_number) }; + let old_idx = idx; + idx += 1; + old_idx + }); + row_number += 1; + + // SAFETY: + // we pre-allocated + unsafe { + row_locations.push_unchecked(idx); + } + } + } + let row_index = match count { + 0 => { + // SAFETY: `unique_indices` is filled with elements between + // 0 and `index_agg.len() - 1`. + let mut s = unsafe { index_agg.take_slice_unchecked(&unique_indices) }; + s.rename(&index[0]); + Some(vec![s]) + }, + _ => None, + }; + + (row_locations, idx as usize, row_index) +} + // TODO! Also create a specialized version for numerics. pub(super) fn compute_row_idx( pivot_df: &DataFrame, @@ -287,13 +395,34 @@ pub(super) fn compute_row_idx( use DataType::*; match index_agg_physical.dtype() { - Int32 | UInt32 | Float32 => { + Int32 | UInt32 => { let ca = index_agg_physical.bit_repr_small(); - compute_row_idx_numeric(index, &ca, count, index_s.dtype()) + compute_row_index(index, &ca, count, index_s.dtype()) }, - Int64 | UInt64 | Float64 => { + Int64 | UInt64 => { let ca = index_agg_physical.bit_repr_large(); - compute_row_idx_numeric(index, &ca, count, index_s.dtype()) + compute_row_index(index, &ca, count, index_s.dtype()) + }, + Float64 => { + let ca: &ChunkedArray = index_agg_physical.as_ref().as_ref().as_ref(); + compute_row_index(index, ca, count, index_s.dtype()) + }, + Float32 => { + let ca: &ChunkedArray = index_agg_physical.as_ref().as_ref().as_ref(); + compute_row_index(index, ca, count, index_s.dtype()) + }, + Boolean => { + let ca = index_agg_physical.bool().unwrap(); + compute_row_index(index, ca, count, index_s.dtype()) + }, + Struct(_) => { + let ca = index_agg_physical.struct_().unwrap(); + let ca = ca.rows_encode()?; + compute_row_index_struct(index, &index_agg, &ca, count) + }, + String => { + let ca = index_agg_physical.str().unwrap(); + compute_row_index(index, ca, count, index_s.dtype()) }, _ => { let mut row_to_idx = @@ -327,61 +456,23 @@ pub(super) fn compute_row_idx( }, } } else { - let index_s = pivot_df.columns(index)?; - let index_agg_physical = index_s - .iter() - .map(|s| unsafe { s.agg_first(groups).to_physical_repr().into_owned() }) - .collect::>(); - let mut iters = index_agg_physical - .iter() - .map(|s| s.phys_iter()) - .collect::>(); - let mut row_to_idx = - PlIndexMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default()); - let mut idx = 0 as IdxSize; - - let mut row_locations = Vec::with_capacity(groups.len()); - loop { - match iters - .iter_mut() - .map(|it| it.next()) - .collect::>>() - { - None => break, - Some(items) => { - let idx = *row_to_idx.entry(items).or_insert_with(|| { - let old_idx = idx; - idx += 1; - old_idx - }); - row_locations.push(idx) - }, - } - } - let row_index = match count { - 0 => Some( - index - .iter() - .enumerate() - .map(|(i, name)| { - let s = Series::new( - name, - row_to_idx - .iter() - .map(|(k, _)| { - debug_assert!(i < k.len()); - unsafe { k.get_unchecked(i).clone() } - }) - .collect::>(), - ); - restore_logical_type(&s, index_s[i].dtype()) - }) - .collect::>(), - ), - _ => None, - }; - - (row_locations, idx as usize, row_index) + let binding = pivot_df.select(index)?; + let fields = binding.get_columns(); + let index_struct_series = StructChunked::new("placeholder", fields)?.into_series(); + let index_agg = unsafe { index_struct_series.agg_first(groups) }; + let index_agg_physical = index_agg.to_physical_repr(); + let ca = index_agg_physical.struct_()?; + let ca = ca.rows_encode()?; + let (row_locations, n_rows, row_index) = + compute_row_index_struct(index, &index_agg, &ca, count); + let row_index = row_index.map(|x| { + unsafe { x.get_unchecked(0) } + .struct_() + .unwrap() + .fields() + .to_vec() + }); + (row_locations, n_rows, row_index) }; Ok((row_locations, n_rows, row_index)) diff --git a/crates/polars-ops/src/series/ops/approx_unique.rs b/crates/polars-ops/src/series/ops/approx_unique.rs index d812e4dcb34d..31093e06b77a 100644 --- a/crates/polars-ops/src/series/ops/approx_unique.rs +++ b/crates/polars-ops/src/series/ops/approx_unique.rs @@ -2,6 +2,7 @@ use std::hash::Hash; use polars_core::prelude::*; use polars_core::with_match_physical_integer_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; #[cfg(feature = "approx_unique")] use crate::series::ops::approx_algo::HyperLogLog; @@ -9,11 +10,11 @@ use crate::series::ops::approx_algo::HyperLogLog; fn approx_n_unique_ca<'a, T>(ca: &'a ChunkedArray) -> PolarsResult where T: PolarsDataType, - &'a ChunkedArray: IntoIterator, - <<&'a ChunkedArray as IntoIterator>::IntoIter as IntoIterator>::Item: Hash + Eq, + T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + > as ToTotalOrd>::TotalOrdItem: Hash + Eq, { let mut hllp = HyperLogLog::new(); - ca.into_iter().for_each(|item| hllp.add(&item)); + ca.iter().for_each(|item| hllp.add(&item.to_total_ord())); let c = hllp.count() as IdxSize; Ok(Series::new(ca.name(), &[c])) @@ -26,12 +27,15 @@ fn dispatcher(s: &Series) -> PolarsResult { Boolean => s.bool().and_then(approx_n_unique_ca), Binary => s.binary().and_then(approx_n_unique_ca), String => { - let s = s.cast(&Binary).unwrap(); - let ca = s.binary().unwrap(); - approx_n_unique_ca(ca) + let ca = s.str().unwrap().as_binary(); + approx_n_unique_ca(&ca) }, - Float32 => approx_n_unique_ca(&s.bit_repr_small()), - Float64 => approx_n_unique_ca(&s.bit_repr_large()), + Float32 => approx_n_unique_ca(AsRef::>::as_ref( + s.as_ref().as_ref(), + )), + Float64 => approx_n_unique_ca(AsRef::>::as_ref( + s.as_ref().as_ref(), + )), dt if dt.is_numeric() => { with_match_physical_integer_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); diff --git a/crates/polars-ops/src/series/ops/arg_min_max.rs b/crates/polars-ops/src/series/ops/arg_min_max.rs index 7c5930f9bf39..563d9c96f430 100644 --- a/crates/polars-ops/src/series/ops/arg_min_max.rs +++ b/crates/polars-ops/src/series/ops/arg_min_max.rs @@ -148,7 +148,7 @@ pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option { Some(first_set_bit(mask)) } else { let mut first_false_idx: Option = None; - ca.into_iter() + ca.iter() .enumerate() .find_map(|(idx, val)| match val { Some(true) => Some(idx), @@ -171,7 +171,7 @@ fn arg_min_bool(ca: &BooleanChunked) -> Option { Some(first_unset_bit(mask)) } else { let mut first_true_idx: Option = None; - ca.into_iter() + ca.iter() .enumerate() .find_map(|(idx, val)| match val { Some(false) => Some(idx), @@ -193,7 +193,7 @@ fn arg_min_str(ca: &StringChunked) -> Option { IsSorted::Ascending => ca.first_non_null(), IsSorted::Descending => ca.last_non_null(), IsSorted::Not => ca - .into_iter() + .iter() .enumerate() .flat_map(|(idx, val)| val.map(|val| (idx, val))) .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) @@ -209,7 +209,7 @@ fn arg_max_str(ca: &StringChunked) -> Option { IsSorted::Ascending => ca.last_non_null(), IsSorted::Descending => ca.first_non_null(), IsSorted::Not => ca - .into_iter() + .iter() .enumerate() .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc }) .map(|tpl| tpl.0), diff --git a/crates/polars-ops/src/series/ops/clip.rs b/crates/polars-ops/src/series/ops/clip.rs index 170e7961d6a2..917b2a24654d 100644 --- a/crates/polars-ops/src/series/ops/clip.rs +++ b/crates/polars-ops/src/series/ops/clip.rs @@ -3,74 +3,15 @@ use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise}; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_polars_type; -fn clip_helper( - ca: &ChunkedArray, - min: &ChunkedArray, - max: &ChunkedArray, -) -> ChunkedArray -where - T: PolarsNumericType, - T::Native: PartialOrd, -{ - match (min.len(), max.len()) { - (1, 1) => match (min.get(0), max.get(0)) { - (Some(min), Some(max)) => { - ca.apply_generic(|s| s.map(|s| num_traits::clamp(s, min, max))) - }, - _ => ChunkedArray::::full_null(ca.name(), ca.len()), - }, - (1, _) => match min.get(0) { - Some(min) => binary_elementwise(ca, max, |opt_s, opt_max| match (opt_s, opt_max) { - (Some(s), Some(max)) => Some(clamp(s, min, max)), - _ => None, - }), - _ => ChunkedArray::::full_null(ca.name(), ca.len()), - }, - (_, 1) => match max.get(0) { - Some(max) => binary_elementwise(ca, min, |opt_s, opt_min| match (opt_s, opt_min) { - (Some(s), Some(min)) => Some(clamp(s, min, max)), - _ => None, - }), - _ => ChunkedArray::::full_null(ca.name(), ca.len()), - }, - _ => ternary_elementwise(ca, min, max, |opt_s, opt_min, opt_max| { - match (opt_s, opt_min, opt_max) { - (Some(s), Some(min), Some(max)) => Some(clamp(s, min, max)), - _ => None, - } - }), - } -} - -fn clip_min_max_helper( - ca: &ChunkedArray, - bound: &ChunkedArray, - op: F, -) -> ChunkedArray -where - T: PolarsNumericType, - T::Native: PartialOrd, - F: Fn(T::Native, T::Native) -> T::Native, -{ - match bound.len() { - 1 => match bound.get(0) { - Some(bound) => ca.apply_generic(|s| s.map(|s| op(s, bound))), - _ => ChunkedArray::::full_null(ca.name(), ca.len()), - }, - _ => binary_elementwise(ca, bound, |opt_s, opt_bound| match (opt_s, opt_bound) { - (Some(s), Some(bound)) => Some(op(s, bound)), - _ => None, - }), - } -} - -/// Clamp underlying values to the `min` and `max` values. +/// Set values outside the given boundaries to the boundary value. pub fn clip(s: &Series, min: &Series, max: &Series) -> PolarsResult { - polars_ensure!(s.dtype().to_physical().is_numeric(), InvalidOperation: "Only physical numeric types are supported."); + polars_ensure!( + s.dtype().to_physical().is_numeric(), + InvalidOperation: "`clip` only supports physical numeric types" + ); let original_type = s.dtype(); - // cast min & max to the dtype of s first. - let (min, max) = (min.cast(s.dtype())?, max.cast(s.dtype())?); + let (min, max) = (min.strict_cast(s.dtype())?, max.strict_cast(s.dtype())?); let (s, min, max) = ( s.to_physical_repr(), @@ -85,9 +26,9 @@ pub fn clip(s: &Series, min: &Series, max: &Series) -> PolarsResult { let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref(); let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref(); let out = clip_helper(ca, min, max).into_series(); - if original_type.is_logical(){ + if original_type.is_logical() { out.cast(original_type) - }else{ + } else { Ok(out) } }) @@ -96,13 +37,15 @@ pub fn clip(s: &Series, min: &Series, max: &Series) -> PolarsResult { } } -/// Clamp underlying values to the `max` value. +/// Set values above the given maximum to the maximum value. pub fn clip_max(s: &Series, max: &Series) -> PolarsResult { - polars_ensure!(s.dtype().to_physical().is_numeric(), InvalidOperation: "Only physical numeric types are supported."); + polars_ensure!( + s.dtype().to_physical().is_numeric(), + InvalidOperation: "`clip` only supports physical numeric types" + ); let original_type = s.dtype(); - // cast max to the dtype of s first. - let max = max.cast(s.dtype())?; + let max = max.strict_cast(s.dtype())?; let (s, max) = (s.to_physical_repr(), max.to_physical_repr()); @@ -112,9 +55,9 @@ pub fn clip_max(s: &Series, max: &Series) -> PolarsResult { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref(); let out = clip_min_max_helper(ca, max, clamp_max).into_series(); - if original_type.is_logical(){ + if original_type.is_logical() { out.cast(original_type) - }else{ + } else { Ok(out) } }) @@ -123,13 +66,15 @@ pub fn clip_max(s: &Series, max: &Series) -> PolarsResult { } } -/// Clamp underlying values to the `min` value. +/// Set values below the given minimum to the minimum value. pub fn clip_min(s: &Series, min: &Series) -> PolarsResult { - polars_ensure!(s.dtype().to_physical().is_numeric(), InvalidOperation: "Only physical numeric types are supported."); + polars_ensure!( + s.dtype().to_physical().is_numeric(), + InvalidOperation: "`clip` only supports physical numeric types" + ); let original_type = s.dtype(); - // cast min to the dtype of s first. - let min = min.cast(s.dtype())?; + let min = min.strict_cast(s.dtype())?; let (s, min) = (s.to_physical_repr(), min.to_physical_repr()); @@ -139,9 +84,9 @@ pub fn clip_min(s: &Series, min: &Series) -> PolarsResult { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref(); let out = clip_min_max_helper(ca, min, clamp_min).into_series(); - if original_type.is_logical(){ + if original_type.is_logical() { out.cast(original_type) - }else{ + } else { Ok(out) } }) @@ -149,3 +94,64 @@ pub fn clip_min(s: &Series, min: &Series) -> PolarsResult { dt => polars_bail!(opq = clippy_min, dt), } } + +fn clip_helper( + ca: &ChunkedArray, + min: &ChunkedArray, + max: &ChunkedArray, +) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, +{ + match (min.len(), max.len()) { + (1, 1) => match (min.get(0), max.get(0)) { + (Some(min), Some(max)) => { + ca.apply_generic(|s| s.map(|s| num_traits::clamp(s, min, max))) + }, + _ => ChunkedArray::::full_null(ca.name(), ca.len()), + }, + (1, _) => match min.get(0) { + Some(min) => binary_elementwise(ca, max, |opt_s, opt_max| match (opt_s, opt_max) { + (Some(s), Some(max)) => Some(clamp(s, min, max)), + _ => None, + }), + _ => ChunkedArray::::full_null(ca.name(), ca.len()), + }, + (_, 1) => match max.get(0) { + Some(max) => binary_elementwise(ca, min, |opt_s, opt_min| match (opt_s, opt_min) { + (Some(s), Some(min)) => Some(clamp(s, min, max)), + _ => None, + }), + _ => ChunkedArray::::full_null(ca.name(), ca.len()), + }, + _ => ternary_elementwise(ca, min, max, |opt_s, opt_min, opt_max| { + match (opt_s, opt_min, opt_max) { + (Some(s), Some(min), Some(max)) => Some(clamp(s, min, max)), + _ => None, + } + }), + } +} + +fn clip_min_max_helper( + ca: &ChunkedArray, + bound: &ChunkedArray, + op: F, +) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, + F: Fn(T::Native, T::Native) -> T::Native, +{ + match bound.len() { + 1 => match bound.get(0) { + Some(bound) => ca.apply_generic(|s| s.map(|s| op(s, bound))), + _ => ChunkedArray::::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise(ca, bound, |opt_s, opt_bound| match (opt_s, opt_bound) { + (Some(s), Some(bound)) => Some(op(s, bound)), + _ => None, + }), + } +} diff --git a/crates/polars-ops/src/series/ops/cum_agg.rs b/crates/polars-ops/src/series/ops/cum_agg.rs index cd2089a0e8f4..db4d199e9abf 100644 --- a/crates/polars-ops/src/series/ops/cum_agg.rs +++ b/crates/polars-ops/src/series/ops/cum_agg.rs @@ -1,7 +1,6 @@ -use std::iter::FromIterator; use std::ops::{Add, AddAssign, Mul}; -use num_traits::Bounded; +use num_traits::{Bounded, One, Zero}; use polars_core::prelude::*; use polars_core::utils::{CustomIterTools, NoNull}; use polars_core::with_match_physical_numeric_polars_type; @@ -36,37 +35,29 @@ where } } -fn det_sum(state: &mut Option, v: Option) -> Option> +fn det_sum(state: &mut T, v: Option) -> Option> where T: Copy + PartialOrd + AddAssign + Add, { - match (*state, v) { - (Some(state_inner), Some(v)) => { - *state = Some(state_inner + v); - Some(*state) - }, - (None, Some(v)) => { - *state = Some(v); - Some(*state) + match v { + Some(v) => { + *state += v; + Some(Some(*state)) }, - (_, None) => Some(None), + None => Some(None), } } -fn det_prod(state: &mut Option, v: Option) -> Option> +fn det_prod(state: &mut T, v: Option) -> Option> where T: Copy + PartialOrd + Mul, { - match (*state, v) { - (Some(state_inner), Some(v)) => { - *state = Some(state_inner * v); - Some(*state) - }, - (None, Some(v)) => { - *state = Some(v); - Some(*state) + match v { + Some(v) => { + *state = *state * v; + Some(Some(*state)) }, - (_, None) => Some(None), + None => Some(None), } } @@ -78,8 +69,8 @@ where let init = Bounded::min_value(); let out: ChunkedArray = match reverse { - false => ca.into_iter().scan(init, det_max).collect_trusted(), - true => ca.into_iter().rev().scan(init, det_max).collect_reversed(), + false => ca.iter().scan(init, det_max).collect_trusted(), + true => ca.iter().rev().scan(init, det_max).collect_reversed(), }; out.with_name(ca.name()) } @@ -91,8 +82,8 @@ where { let init = Bounded::max_value(); let out: ChunkedArray = match reverse { - false => ca.into_iter().scan(init, det_min).collect_trusted(), - true => ca.into_iter().rev().scan(init, det_min).collect_reversed(), + false => ca.iter().scan(init, det_min).collect_trusted(), + true => ca.iter().rev().scan(init, det_min).collect_reversed(), }; out.with_name(ca.name()) } @@ -102,10 +93,10 @@ where T: PolarsNumericType, ChunkedArray: FromIterator>, { - let init = None; + let init = T::Native::zero(); let out: ChunkedArray = match reverse { - false => ca.into_iter().scan(init, det_sum).collect_trusted(), - true => ca.into_iter().rev().scan(init, det_sum).collect_reversed(), + false => ca.iter().scan(init, det_sum).collect_trusted(), + true => ca.iter().rev().scan(init, det_sum).collect_reversed(), }; out.with_name(ca.name()) } @@ -115,10 +106,10 @@ where T: PolarsNumericType, ChunkedArray: FromIterator>, { - let init = None; + let init = T::Native::one(); let out: ChunkedArray = match reverse { - false => ca.into_iter().scan(init, det_prod).collect_trusted(), - true => ca.into_iter().rev().scan(init, det_prod).collect_reversed(), + false => ca.iter().scan(init, det_prod).collect_trusted(), + true => ca.iter().rev().scan(init, det_prod).collect_reversed(), }; out.with_name(ca.name()) } diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index df9ee97c4b32..b7e87a23d8a8 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -1,6 +1,3 @@ -use std::cmp::PartialOrd; -use std::iter::once; - use polars_core::prelude::*; fn map_cats( @@ -82,8 +79,8 @@ pub fn cut( polars_ensure!(ll.len() == sorted_breaks.len() + 1, ShapeMismatch: "Provide nbreaks + 1 labels"); ll }, - None => (once(&f64::NEG_INFINITY).chain(sorted_breaks.iter())) - .zip(sorted_breaks.iter().chain(once(&f64::INFINITY))) + None => (std::iter::once(&f64::NEG_INFINITY).chain(sorted_breaks.iter())) + .zip(sorted_breaks.iter().chain(std::iter::once(&f64::INFINITY))) .map(|v| { if left_closed { format!("[{}, {})", v.0, v.1) diff --git a/crates/polars-ops/src/series/ops/ewm.rs b/crates/polars-ops/src/series/ops/ewm.rs index cbc16b6abc37..22b99a04a892 100644 --- a/crates/polars-ops/src/series/ops/ewm.rs +++ b/crates/polars-ops/src/series/ops/ewm.rs @@ -1,6 +1,3 @@ -use std::convert::TryFrom; - -use arrow::array::ArrayRef; pub use arrow::legacy::kernels::ewm::EWMOptions; use arrow::legacy::kernels::ewm::{ ewm_mean as kernel_ewm_mean, ewm_std as kernel_ewm_std, ewm_var as kernel_ewm_var, diff --git a/crates/polars-ops/src/series/ops/floor_divide.rs b/crates/polars-ops/src/series/ops/floor_divide.rs index 4220ef5281d9..85d8750a5313 100644 --- a/crates/polars-ops/src/series/ops/floor_divide.rs +++ b/crates/polars-ops/src/series/ops/floor_divide.rs @@ -1,84 +1,21 @@ -use arrow::array::{Array, PrimitiveArray}; -use arrow::compute::utils::combine_validities_and; -use num::NumCast; -use polars_core::datatypes::PolarsNumericType; -use polars_core::export::num; +use polars_compute::arithmetic::ArithmeticKernel; +use polars_core::chunked_array::ops::arity::apply_binary_kernel_broadcast; use polars_core::prelude::*; #[cfg(feature = "dtype-struct")] use polars_core::series::arithmetic::_struct_arithmetic; use polars_core::with_match_physical_numeric_polars_type; -#[inline] -fn floor_div_element(a: T, b: T) -> T { - // Safety: the casts of those primitives always succeed - unsafe { - let a: f64 = NumCast::from(a).unwrap_unchecked(); - let b: f64 = NumCast::from(b).unwrap_unchecked(); - - let out = (a / b).floor(); - let out: T = NumCast::from(out).unwrap_unchecked(); - out - } -} - -fn floor_div_array( - a: &PrimitiveArray, - b: &PrimitiveArray, -) -> PrimitiveArray { - assert_eq!(a.len(), b.len()); - - if a.null_count() == 0 && b.null_count() == 0 { - let values = a - .values() - .as_slice() - .iter() - .copied() - .zip(b.values().as_slice().iter().copied()) - .map(|(a, b)| floor_div_element(a, b)) - .collect::>(); - - let validity = combine_validities_and(a.validity(), b.validity()); - - PrimitiveArray::new(a.data_type().clone(), values.into(), validity) - } else { - let iter = a - .into_iter() - .zip(b) - .map(|(opt_a, opt_b)| match (opt_a, opt_b) { - (Some(&a), Some(&b)) => Some(floor_div_element(a, b)), - _ => None, - }); - PrimitiveArray::from_trusted_len_iter(iter) - } -} - -fn floor_div_ca(a: &ChunkedArray, b: &ChunkedArray) -> ChunkedArray { - if a.len() == 1 { - let name = a.name(); - return if let Some(a) = a.get(0) { - let mut out = if b.null_count() == 0 { - b.apply_values(|b| floor_div_element(a, b)) - } else { - b.apply(|b| b.map(|b| floor_div_element(a, b))) - }; - out.rename(name); - out - } else { - ChunkedArray::full_null(a.name(), b.len()) - }; - } - if b.len() == 1 { - return if let Some(b) = b.get(0) { - if a.null_count() == 0 { - a.apply_values(|a| floor_div_element(a, b)) - } else { - a.apply(|a| a.map(|a| floor_div_element(a, b))) - } - } else { - ChunkedArray::full_null(a.name(), a.len()) - }; - } - arity::binary(a, b, floor_div_array) +fn floor_div_ca( + lhs: &ChunkedArray, + rhs: &ChunkedArray, +) -> ChunkedArray { + apply_binary_kernel_broadcast( + lhs, + rhs, + |l, r| ArithmeticKernel::wrapping_floor_div(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_floor_div_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::wrapping_floor_div_scalar(l.clone(), r), + ) } pub fn floor_div_series(a: &Series, b: &Series) -> PolarsResult { diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index 1228f1d71bec..fd4dd76d2434 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -5,12 +5,6 @@ use polars_core::prelude::*; use polars_core::POOL; use rayon::prelude::*; -pub fn sum_horizontal(s: &[Series]) -> PolarsResult> { - let df = DataFrame::new_no_checks(Vec::from(s)); - df.sum_horizontal(NullStrategy::Ignore) - .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) -} - pub fn any_horizontal(s: &[Series]) -> PolarsResult { let out = POOL .install(|| { @@ -48,17 +42,29 @@ pub fn all_horizontal(s: &[Series]) -> PolarsResult { } pub fn max_horizontal(s: &[Series]) -> PolarsResult> { - let df = DataFrame::new_no_checks(Vec::from(s)); + let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; df.max_horizontal() .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) } pub fn min_horizontal(s: &[Series]) -> PolarsResult> { - let df = DataFrame::new_no_checks(Vec::from(s)); + let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; df.min_horizontal() .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) } +pub fn sum_horizontal(s: &[Series]) -> PolarsResult> { + let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; + df.sum_horizontal(NullStrategy::Ignore) + .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) +} + +pub fn mean_horizontal(s: &[Series]) -> PolarsResult> { + let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; + df.mean_horizontal(NullStrategy::Ignore) + .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) +} + pub fn coalesce_series(s: &[Series]) -> PolarsResult { // TODO! this can be faster if we have more than two inputs. polars_ensure!(!s.is_empty(), NoData: "cannot coalesce empty list"); diff --git a/crates/polars-ops/src/series/ops/is_first_distinct.rs b/crates/polars-ops/src/series/ops/is_first_distinct.rs index 178c80bb980d..b75ae23dba1f 100644 --- a/crates/polars-ops/src/series/ops/is_first_distinct.rs +++ b/crates/polars-ops/src/series/ops/is_first_distinct.rs @@ -5,16 +5,18 @@ use arrow::bitmap::MutableBitmap; use arrow::legacy::bit_util::*; use arrow::legacy::utils::CustomIterTools; use polars_core::prelude::*; -use polars_core::with_match_physical_integer_polars_type; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; fn is_first_distinct_numeric(ca: &ChunkedArray) -> BooleanChunked where T: PolarsNumericType, - T::Native: Hash + Eq, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let mut unique = PlHashSet::new(); let chunks = ca.downcast_iter().map(|arr| -> BooleanArray { arr.into_iter() - .map(|opt_v| unique.insert(opt_v)) + .map(|opt_v| unique.insert(opt_v.to_total_ord())) .collect_trusted() }); @@ -126,16 +128,8 @@ pub fn is_first_distinct(s: &Series) -> PolarsResult { let s = s.cast(&Binary).unwrap(); return is_first_distinct(&s); }, - Float32 => { - let ca = s.bit_repr_small(); - is_first_distinct_numeric(&ca) - }, - Float64 => { - let ca = s.bit_repr_large(); - is_first_distinct_numeric(&ca) - }, dt if dt.is_numeric() => { - with_match_physical_integer_polars_type!(s.dtype(), |$T| { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); is_first_distinct_numeric(ca) }) diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs index fe5218df6dfc..27c7caf5dce9 100644 --- a/crates/polars-ops/src/series/ops/is_in.rs +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -1,3 +1,5 @@ +use std::hash::Hash; + #[cfg(feature = "dtype-categorical")] use polars_core::apply_amortized_generic_list_or_array; use polars_core::prelude::*; @@ -5,7 +7,7 @@ use polars_core::utils::{try_get_supertype, CustomIterTools}; use polars_core::with_match_physical_numeric_polars_type; #[cfg(feature = "dtype-categorical")] use polars_utils::iter::EnumerateIdxTrait; -use polars_utils::total_ord::{TotalEq, TotalHash, TotalOrdWrap}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; fn is_in_helper_ca<'a, T>( ca: &'a ChunkedArray, @@ -13,25 +15,27 @@ fn is_in_helper_ca<'a, T>( ) -> PolarsResult where T: PolarsDataType, - T::Physical<'a>: TotalHash + TotalEq + Copy, + T::Physical<'a>: TotalHash + TotalEq + ToTotalOrd + Copy, + as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, { let mut set = PlHashSet::with_capacity(other.len()); other.downcast_iter().for_each(|iter| { iter.iter().for_each(|opt_val| { if let Some(v) = opt_val { - set.insert(TotalOrdWrap(v)); + set.insert(v.to_total_ord()); } }) }); Ok(ca - .apply_values_generic(|val| set.contains(&TotalOrdWrap(val))) + .apply_values_generic(|val| set.contains(&val.to_total_ord())) .with_name(ca.name())) } fn is_in_helper<'a, T>(ca: &'a ChunkedArray, other: &Series) -> PolarsResult where T: PolarsDataType, - T::Physical<'a>: TotalHash + TotalEq + Copy, + T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, { let other = ca.unpack_series_matching_type(other)?; is_in_helper_ca(ca, other) @@ -49,7 +53,7 @@ where Some( opt_s.map(|s| { let ca = s.as_ref().unpack::().unwrap(); - ca.into_iter().any(|a| a == value) + ca.iter().any(|a| a == value) }) == Some(true), ) }) @@ -58,12 +62,12 @@ where // SAFETY: unstable series never lives longer than the iterator. unsafe { ca_in - .into_iter() + .iter() .zip(other.list()?.amortized_iter()) .map(|(value, series)| match (value, series) { (val, Some(series)) => { let ca = series.as_ref().unpack::().unwrap(); - ca.into_iter().any(|a| a == val) + ca.iter().any(|a| a == val) }, _ => false, }) @@ -87,19 +91,19 @@ where Some( opt_s.map(|s| { let ca = s.as_ref().unpack::().unwrap(); - ca.into_iter().any(|a| a == value) + ca.iter().any(|a| a == value) }) == Some(true), ) }) } else { polars_ensure!(ca_in.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", ca_in.len(), other.len()); ca_in - .into_iter() + .iter() .zip(other.array()?.amortized_iter()) .map(|(value, series)| match (value, series) { (val, Some(series)) => { let ca = series.as_ref().unpack::().unwrap(); - ca.into_iter().any(|a| a == val) + ca.iter().any(|a| a == val) }, _ => false, }) @@ -112,7 +116,8 @@ where fn is_in_numeric(ca_in: &ChunkedArray, other: &Series) -> PolarsResult where T: PolarsNumericType, - T::Native: TotalHash + TotalEq, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + Copy, { // We check implicitly cast to supertype here match other.dtype() { @@ -179,7 +184,7 @@ fn is_in_string_inner_categorical( if ca.null_count() == 0 { ca.into_no_null_iter().any(|a| a == idx) } else { - ca.into_iter().any(|a| a == Some(idx)) + ca.iter().any(|a| a == Some(idx)) } }) == Some(true), ) @@ -232,6 +237,10 @@ fn is_in_string(ca_in: &StringChunked, other: &Series) -> PolarsResult { is_in_binary(&ca_in.as_binary(), &other.cast(&DataType::Binary).unwrap()) }, + #[cfg(feature = "dtype-categorical")] + DataType::Enum(_, _) | DataType::Categorical(_, _) => { + is_in_string_categorical(ca_in, other.categorical().unwrap()) + }, _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), } } @@ -244,7 +253,7 @@ fn is_in_binary_list(ca_in: &BinaryChunked, other: &Series) -> PolarsResult().unwrap(); - ca.into_iter().any(|a| a == value) + ca.iter().any(|a| a == value) }) == Some(true), ) }) @@ -253,12 +262,12 @@ fn is_in_binary_list(ca_in: &BinaryChunked, other: &Series) -> PolarsResult { let ca = series.as_ref().unpack::().unwrap(); - ca.into_iter().any(|a| a == val) + ca.iter().any(|a| a == val) }, _ => false, }) @@ -278,19 +287,19 @@ fn is_in_binary_array(ca_in: &BinaryChunked, other: &Series) -> PolarsResult().unwrap(); - ca.into_iter().any(|a| a == value) + ca.iter().any(|a| a == value) }) == Some(true), ) }) } else { polars_ensure!(ca_in.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", ca_in.len(), other.len()); ca_in - .into_iter() + .iter() .zip(other.array()?.amortized_iter()) .map(|(value, series)| match (value, series) { (val, Some(series)) => { let ca = series.as_ref().unpack::().unwrap(); - ca.into_iter().any(|a| a == val) + ca.iter().any(|a| a == val) }, _ => false, }) @@ -322,7 +331,7 @@ fn is_in_boolean_list(ca_in: &BooleanChunked, other: &Series) -> PolarsResult().unwrap(); - ca.into_iter().any(|a| a == value) + ca.iter().any(|a| a == value) }) == Some(true) }) .trust_my_length(other.len()) @@ -333,12 +342,12 @@ fn is_in_boolean_list(ca_in: &BooleanChunked, other: &Series) -> PolarsResult { let ca = series.as_ref().unpack::().unwrap(); - ca.into_iter().any(|a| a == val) + ca.iter().any(|a| a == val) }, _ => false, }) @@ -361,7 +370,7 @@ fn is_in_boolean_array(ca_in: &BooleanChunked, other: &Series) -> PolarsResult().unwrap(); - ca.into_iter().any(|a| a == value) + ca.iter().any(|a| a == value) }) == Some(true) }) .trust_my_length(other.len()) @@ -370,12 +379,12 @@ fn is_in_boolean_array(ca_in: &BooleanChunked, other: &Series) -> PolarsResult { let ca = series.as_ref().unpack::().unwrap(); - ca.into_iter().any(|a| a == val) + ca.iter().any(|a| a == val) }, _ => false, }) @@ -421,7 +430,7 @@ fn is_in_struct_list(ca_in: &StructChunked, other: &Series) -> PolarsResult PolarsResult { let ca = series.as_ref().struct_().unwrap(); - ca.into_iter().any(|a| a == val) + ca.iter().any(|a| a == val) }, _ => false, }) @@ -459,19 +468,19 @@ fn is_in_struct_array(ca_in: &StructChunked, other: &Series) -> PolarsResult { let ca = series.as_ref().struct_().unwrap(); - ca.into_iter().any(|a| a == val) + ca.iter().any(|a| a == val) }, _ => false, }) @@ -521,17 +530,17 @@ fn is_in_struct(ca_in: &StructChunked, other: &Series) -> PolarsResult PolarsResult PolarsResult PolarsResult { + // In case of fast unique, we can directly use the categories. Otherwise we need to + // first get the unique physicals + let categories = StringChunked::with_chunk("", other.get_rev_map().get_categories().clone()); + let other = if other._can_fast_unique() { + categories + } else { + let s = other.physical().unique()?.cast(&IDX_DTYPE)?; + // SAFETY: Invariant of categorical means indices are in bound + unsafe { categories.take_unchecked(s.idx()?) } + }; + is_in_helper_ca(&ca_in.as_binary(), &other.as_binary()) +} + #[cfg(feature = "dtype-categorical")] fn is_in_cat(ca_in: &CategoricalChunked, other: &Series) -> PolarsResult { match other.dtype() { @@ -575,12 +602,12 @@ fn is_in_cat(ca_in: &CategoricalChunked, other: &Series) -> PolarsResult { for (global_idx, local_idx) in hash_map.iter() { - // Safety: index is in bounds + // SAFETY: index is in bounds if others .contains(unsafe { categories.value_unchecked(*local_idx as usize) }) { #[allow(clippy::unnecessary_cast)] - set.insert(TotalOrdWrap(*global_idx as u32)); + set.insert((*global_idx as u32).to_total_ord()); } } }, @@ -591,7 +618,7 @@ fn is_in_cat(ca_in: &CategoricalChunked, other: &Series) -> PolarsResult PolarsResult PolarsResult { // fast path. @@ -31,16 +32,8 @@ pub fn is_last_distinct(s: &Series) -> PolarsResult { let s = s.cast(&Binary).unwrap(); return is_last_distinct(&s); }, - Float32 => { - let ca = s.bit_repr_small(); - is_last_distinct_numeric(&ca) - }, - Float64 => { - let ca = s.bit_repr_large(); - is_last_distinct_numeric(&ca) - }, dt if dt.is_numeric() => { - with_match_physical_integer_polars_type!(s.dtype(), |$T| { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); is_last_distinct_numeric(ca) }) @@ -131,7 +124,8 @@ fn is_last_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { fn is_last_distinct_numeric(ca: &ChunkedArray) -> BooleanChunked where T: PolarsNumericType, - T::Native: Hash + Eq, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); @@ -139,7 +133,7 @@ where let mut new_ca: BooleanChunked = arr .into_iter() .rev() - .map(|opt_v| unique.insert(opt_v)) + .map(|opt_v| unique.insert(opt_v.to_total_ord())) .collect_reversed::>() .into_inner(); new_ca.rename(ca.name()); diff --git a/crates/polars-ops/src/series/ops/is_unique.rs b/crates/polars-ops/src/series/ops/is_unique.rs index fee3703839eb..265e8736b35e 100644 --- a/crates/polars-ops/src/series/ops/is_unique.rs +++ b/crates/polars-ops/src/series/ops/is_unique.rs @@ -1,24 +1,26 @@ +use std::hash::Hash; + use arrow::array::BooleanArray; use arrow::bitmap::MutableBitmap; use polars_core::prelude::*; use polars_core::with_match_physical_integer_polars_type; -use polars_utils::total_ord::{TotalEq, TotalHash, TotalOrdWrap}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; // If invert is true then this is an `is_duplicated`. fn is_unique_ca<'a, T>(ca: &'a ChunkedArray, invert: bool) -> BooleanChunked where T: PolarsDataType, - &'a ChunkedArray: IntoIterator, - <<&'a ChunkedArray as IntoIterator>::IntoIter as IntoIterator>::Item: TotalHash + TotalEq, + T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + > as ToTotalOrd>::TotalOrdItem: Hash + Eq, { let len = ca.len(); let mut idx_key = PlHashMap::new(); // Instead of group_tuples, which allocates a full Vec per group, we now // just toggle a boolean that's false if a group has multiple entries. - ca.into_iter().enumerate().for_each(|(idx, key)| { + ca.iter().enumerate().for_each(|(idx, key)| { idx_key - .entry(TotalOrdWrap(key)) + .entry(key.to_total_ord()) .and_modify(|v: &mut (IdxSize, bool)| v.1 = false) .or_insert((idx as IdxSize, true)); }); diff --git a/crates/polars-ops/src/series/ops/log.rs b/crates/polars-ops/src/series/ops/log.rs index 118b287f340e..1650afe2d7d4 100644 --- a/crates/polars-ops/src/series/ops/log.rs +++ b/crates/polars-ops/src/series/ops/log.rs @@ -74,6 +74,11 @@ pub trait LogSeries: SeriesSealed { /// where `pk` are discrete probabilities. fn entropy(&self, base: f64, normalize: bool) -> PolarsResult { let s = self.as_series().to_physical_repr(); + // if there is only one value in the series, return 0.0 to prevent the + // function from returning -0.0 + if s.len() == 1 { + return Ok(0.0); + } match s.dtype() { DataType::Float32 | DataType::Float64 => { let pk = s.as_ref(); diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 9b2482ca898d..8a64afbd9fbc 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -118,6 +118,8 @@ pub use to_dummies::*; #[cfg(feature = "unique_counts")] pub use unique::*; pub use various::*; +mod not; +pub use not::*; pub trait SeriesSealed { fn as_series(&self) -> &Series; diff --git a/crates/polars-ops/src/series/ops/not.rs b/crates/polars-ops/src/series/ops/not.rs new file mode 100644 index 000000000000..2bb153166254 --- /dev/null +++ b/crates/polars-ops/src/series/ops/not.rs @@ -0,0 +1,18 @@ +use std::ops::Not; + +use polars_core::with_match_physical_integer_polars_type; + +use super::*; + +pub fn negate_bitwise(s: &Series) -> PolarsResult { + match s.dtype() { + DataType::Boolean => Ok(s.bool().unwrap().not().into_series()), + dt if dt.is_integer() => { + with_match_physical_integer_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s.as_any().downcast_ref().unwrap(); + Ok(ca.apply_values(|v| !v).into_series()) + }) + }, + dt => polars_bail!(InvalidOperation: "dtype {:?} not supported in 'not' operation", dt), + } +} diff --git a/crates/polars-ops/src/series/ops/rank.rs b/crates/polars-ops/src/series/ops/rank.rs index 8bcc3347fc66..dd2fe3936945 100644 --- a/crates/polars-ops/src/series/ops/rank.rs +++ b/crates/polars-ops/src/series/ops/rank.rs @@ -1,11 +1,7 @@ use arrow::array::BooleanArray; use arrow::compute::concatenate::concatenate_validities; use polars_core::prelude::*; -#[cfg(feature = "random")] -use rand::prelude::SliceRandom; use rand::prelude::*; -#[cfg(feature = "random")] -use rand::{rngs::SmallRng, SeedableRng}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-ops/src/series/ops/replace.rs b/crates/polars-ops/src/series/ops/replace.rs index 752355b68242..e40a6aa3ec0a 100644 --- a/crates/polars-ops/src/series/ops/replace.rs +++ b/crates/polars-ops/src/series/ops/replace.rs @@ -2,11 +2,10 @@ use std::ops::BitOr; use polars_core::prelude::*; use polars_core::utils::try_get_supertype; -use polars_error::{polars_bail, polars_ensure, PolarsResult}; +use polars_error::{polars_bail, polars_ensure}; use crate::frame::join::*; use crate::prelude::*; -use crate::series::is_in; pub fn replace( s: &Series, @@ -90,7 +89,7 @@ fn replace_by_multiple( ComputeError: "`new` input for `replace` must have the same length as `old` or have length 1" ); - let df = DataFrame::new_no_checks(vec![s.clone()]); + let df = s.clone().into_frame(); let replacer = create_replacer(old, new)?; let joined = df.join( @@ -133,6 +132,6 @@ fn create_replacer(mut old: Series, mut new: Series) -> PolarsResult } else { vec![old, new] }; - let out = DataFrame::new_no_checks(cols); + let out = unsafe { DataFrame::new_no_checks(cols) }; Ok(out) } diff --git a/crates/polars-ops/src/series/ops/to_dummies.rs b/crates/polars-ops/src/series/ops/to_dummies.rs index 3d68161a304e..f2d8c4f3b70a 100644 --- a/crates/polars-ops/src/series/ops/to_dummies.rs +++ b/crates/polars-ops/src/series/ops/to_dummies.rs @@ -1,5 +1,3 @@ -use polars_core::frame::group_by::GroupsIndicator; - use super::*; #[cfg(feature = "dtype-u8")] @@ -22,7 +20,7 @@ impl ToDummies for Series { let col_name = self.name(); let groups = self.group_tuples(true, drop_first)?; - // safety: groups are in bounds + // SAFETY: groups are in bounds let columns = unsafe { self.agg_first(&groups) }; let columns = columns.iter().zip(groups.iter()).skip(drop_first as usize); let columns = columns @@ -48,7 +46,7 @@ impl ToDummies for Series { }) .collect(); - Ok(DataFrame::new_no_checks(sort_columns(columns))) + Ok(unsafe { DataFrame::new_no_checks(sort_columns(columns)) }) } } diff --git a/crates/polars-ops/src/series/ops/unique.rs b/crates/polars-ops/src/series/ops/unique.rs index 7c4c44618154..3a2d9b5652fe 100644 --- a/crates/polars-ops/src/series/ops/unique.rs +++ b/crates/polars-ops/src/series/ops/unique.rs @@ -3,14 +3,18 @@ use std::hash::Hash; use polars_core::hashing::_HASHMAP_INIT_SIZE; use polars_core::prelude::*; use polars_core::utils::NoNull; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; fn unique_counts_helper(items: I) -> IdxCa where I: Iterator, - J: Hash + Eq, + J: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let mut map = PlIndexMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); for item in items { + let item = item.to_total_ord(); map.entry(item) .and_modify(|cnt| { *cnt += 1; @@ -24,13 +28,12 @@ where /// Returns a count of the unique values in the order of appearance. pub fn unique_counts(s: &Series) -> PolarsResult { if s.dtype().to_physical().is_numeric() { - if s.bit_repr_is_large() { - let ca = s.bit_repr_large(); - Ok(unique_counts_helper(ca.into_iter()).into_series()) - } else { - let ca = s.bit_repr_small(); - Ok(unique_counts_helper(ca.into_iter()).into_series()) - } + let s_physical = s.to_physical_repr(); + + with_match_physical_numeric_polars_type!(s_physical.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s_physical.as_ref().as_ref().as_ref(); + Ok(unique_counts_helper(ca.iter()).into_series()) + }) } else { match s.dtype() { DataType::String => { diff --git a/crates/polars-ops/src/series/ops/various.rs b/crates/polars-ops/src/series/ops/various.rs index ecc341ea020e..cad413816ced 100644 --- a/crates/polars-ops/src/series/ops/various.rs +++ b/crates/polars-ops/src/series/ops/various.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "hash")] -use polars_core::export::ahash; #[cfg(feature = "dtype-struct")] use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca; use polars_core::prelude::*; @@ -21,7 +19,7 @@ pub trait SeriesMethods: SeriesSealed { let values = unsafe { s.agg_first(&groups) }; let counts = groups.group_lengths("count"); let cols = vec![values, counts.into_series()]; - let df = DataFrame::new_no_checks(cols); + let df = unsafe { DataFrame::new_no_checks(cols) }; if sort { df.sort(["count"], true, false) } else { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs index 13c01d9bca62..11a16351ea45 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs @@ -80,6 +80,11 @@ impl<'a, O: Offset> Pushable<&'a [u8]> for Binary { assert_eq!(value.len(), 0); self.extend_constant(additional) } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional) + } } #[derive(Debug)] diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/basic.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/basic.rs index ce0fda8fe3e3..da353d83aa73 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binview/basic.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/basic.rs @@ -57,6 +57,8 @@ impl<'a> utils::Decoder<'a> for BinViewDecoder { additional: usize, ) -> PolarsResult<()> { let (values, validity) = decoded; + let views_offset = values.views().len(); + let buffer_offset = values.completed_buffers().len(); let mut validate_utf8 = self.check_utf8.take(); match state { @@ -190,7 +192,7 @@ impl<'a> utils::Decoder<'a> for BinViewDecoder { } if validate_utf8 { - values.validate_utf8() + values.validate_utf8(buffer_offset, views_offset) } else { Ok(()) } @@ -278,7 +280,7 @@ pub(super) fn finish( .boxed()) }, PhysicalType::Utf8View => { - // Safety: we already checked utf8 + // SAFETY: we already checked utf8 unsafe { Ok(Utf8ViewArray::new_unchecked( data_type.clone(), diff --git a/crates/polars-parquet/src/arrow/read/deserialize/boolean/basic.rs b/crates/polars-parquet/src/arrow/read/deserialize/boolean/basic.rs index c1913fcfdb5d..5f2d1107cd49 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/boolean/basic.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/boolean/basic.rs @@ -11,8 +11,10 @@ use super::super::utils::{ FilteredOptionalPageValidity, MaybeNext, OptionalPageValidity, }; use super::super::{utils, PagesIter}; -use crate::parquet::deserialize::SliceFilteredIter; -use crate::parquet::encoding::Encoding; +use crate::parquet::deserialize::{ + HybridDecoderBitmapIter, HybridRleBooleanIter, SliceFilteredIter, +}; +use crate::parquet::encoding::{hybrid_rle, Encoding}; use crate::parquet::page::{split_buffer, DataPage, DictPage}; #[derive(Debug)] @@ -75,6 +77,10 @@ enum State<'a> { Required(Required<'a>), FilteredRequired(FilteredRequired<'a>), FilteredOptional(FilteredOptionalPageValidity<'a>, Values<'a>), + RleOptional( + OptionalPageValidity<'a>, + HybridRleBooleanIter<'a, HybridDecoderBitmapIter<'a>>, + ), } impl<'a> State<'a> { @@ -84,6 +90,7 @@ impl<'a> State<'a> { State::Required(page) => page.length - page.offset, State::FilteredRequired(page) => page.len(), State::FilteredOptional(optional, _) => optional.len(), + State::RleOptional(optional, _) => optional.len(), } } } @@ -129,6 +136,17 @@ impl<'a> Decoder<'a> for BooleanDecoder { (Encoding::Plain, false, true) => { Ok(State::FilteredRequired(FilteredRequired::try_new(page)?)) }, + (Encoding::Rle, true, false) => { + let optional = OptionalPageValidity::try_new(page)?; + let (_, _, values) = split_buffer(page)?; + // For boolean values the length is pre-pended. + let (_len_in_bytes, values) = values.split_at(4); + let iter = hybrid_rle::Decoder::new(values, 1); + let values = HybridDecoderBitmapIter::new(iter, page.num_values()); + let values = HybridRleBooleanIter::new(values); + + Ok(State::RleOptional(optional, values)) + }, _ => Err(utils::not_implemented(page)), } } @@ -175,6 +193,15 @@ impl<'a> Decoder<'a> for BooleanDecoder { page_values.0.by_ref(), ); }, + State::RleOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values.map(|v| v.unwrap()), + ); + }, } Ok(()) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/utils.rs index 219b0a51ba2e..50a89442570a 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/utils.rs @@ -55,4 +55,9 @@ impl<'a> Pushable<&'a [u8]> for FixedSizeBinary { fn len(&self) -> usize { self.values.len() / self.size } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional) + } } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs index b62a016ad576..d478beeb1c01 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs @@ -1,12 +1,9 @@ use arrow::array::PrimitiveArray; -use arrow::datatypes::{ArrowDataType, Field}; use arrow::match_integer_type; use ethnum::I256; use polars_error::polars_bail; -use super::nested_utils::{InitNested, NestedArrayIter}; use super::*; -use crate::parquet::schema::types::PrimitiveType; /// Converts an iterator of arrays to a trait object returning trait objects #[inline] diff --git a/crates/polars-parquet/src/arrow/read/deserialize/null/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/null/mod.rs index 6e02a57e65fb..da512675a4d4 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/null/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/null/mod.rs @@ -61,6 +61,7 @@ mod tests { use super::iter_to_arrays; use crate::parquet::encoding::Encoding; use crate::parquet::error::Error as ParquetError; + #[allow(unused_imports)] use crate::parquet::fallible_streaming_iterator; use crate::parquet::metadata::Descriptor; use crate::parquet::page::{DataPage, DataPageHeader, DataPageHeaderV1, Page}; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/dictionary.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/dictionary.rs index af53485ff93f..1a1759252ab8 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/dictionary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/dictionary.rs @@ -6,7 +6,7 @@ use arrow::datatypes::ArrowDataType; use arrow::types::NativeType; use polars_error::PolarsResult; -use super::super::dictionary::{nested_next_dict, *}; +use super::super::dictionary::*; use super::super::nested_utils::{InitNested, NestedState}; use super::super::utils::MaybeNext; use super::super::PagesIter; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils.rs index f41f98b46bf1..6919dd88dd74 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils.rs @@ -227,7 +227,7 @@ impl<'a> PageValidity<'a> for OptionalPageValidity<'a> { } } -fn reserve_pushable_and_validity<'a, T: Default, P: Pushable>( +fn reserve_pushable_and_validity<'a, T, P: Pushable>( validity: &mut MutableBitmap, page_validity: &'a mut dyn PageValidity, limit: Option, @@ -263,7 +263,7 @@ fn reserve_pushable_and_validity<'a, T: Default, P: Pushable>( } /// Extends a [`Pushable`] from an iterator of non-null values and an hybrid-rle decoder -pub(super) fn extend_from_decoder, I: Iterator>( +pub(super) fn extend_from_decoder, I: Iterator>( validity: &mut MutableBitmap, page_validity: &mut dyn PageValidity, limit: Option, @@ -300,7 +300,7 @@ pub(super) fn extend_from_decoder, I: Iterator for _ in values_iter.by_ref().take(valids) {}, diff --git a/crates/polars-parquet/src/arrow/read/indexes/mod.rs b/crates/polars-parquet/src/arrow/read/indexes/mod.rs index 0a8184bc2723..a9a48a98e8f4 100644 --- a/crates/polars-parquet/src/arrow/read/indexes/mod.rs +++ b/crates/polars-parquet/src/arrow/read/indexes/mod.rs @@ -184,7 +184,9 @@ fn deserialize( PhysicalType::Binary | PhysicalType::LargeBinary | PhysicalType::Utf8 - | PhysicalType::LargeUtf8 => { + | PhysicalType::LargeUtf8 + | PhysicalType::Utf8View + | PhysicalType::BinaryView => { let index = indexes .pop_front() .unwrap() diff --git a/crates/polars-parquet/src/arrow/read/schema/convert.rs b/crates/polars-parquet/src/arrow/read/schema/convert.rs index 5eeaa94a1355..d4d4ef15f19b 100644 --- a/crates/polars-parquet/src/arrow/read/schema/convert.rs +++ b/crates/polars-parquet/src/arrow/read/schema/convert.rs @@ -407,7 +407,6 @@ pub(crate) fn to_data_type( #[cfg(test)] mod tests { - use arrow::datatypes::{ArrowDataType, Field, TimeUnit}; use polars_error::*; use super::*; diff --git a/crates/polars-parquet/src/arrow/write/pages.rs b/crates/polars-parquet/src/arrow/write/pages.rs index 95b6e91d2b3d..516e7ebd4754 100644 --- a/crates/polars-parquet/src/arrow/write/pages.rs +++ b/crates/polars-parquet/src/arrow/write/pages.rs @@ -262,10 +262,9 @@ pub fn array_to_columns + Send + Sync>( #[cfg(test)] mod tests { use arrow::array::*; - use arrow::bitmap::Bitmap; use arrow::datatypes::*; - use super::super::{FieldInfo, ParquetPhysicalType, ParquetPrimitiveType}; + use super::super::{FieldInfo, ParquetPhysicalType}; use super::*; use crate::parquet::schema::types::{ GroupLogicalType, PrimitiveConvertedType, PrimitiveLogicalType, diff --git a/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs b/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs index 576f4d5f1aba..e8672648cda4 100644 --- a/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs +++ b/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - /// magic numbers taken from https://github.com/apache/parquet-format/blob/master/BloomFilter.md const SALT: [u32; 8] = [ 1203114875, 1150766481, 2284105051, 2729912477, 1884591559, 770785867, 2667333959, 1550580529, diff --git a/crates/polars-parquet/src/parquet/compression.rs b/crates/polars-parquet/src/parquet/compression.rs index 6d1658cd360b..3e638eeb05c7 100644 --- a/crates/polars-parquet/src/parquet/compression.rs +++ b/crates/polars-parquet/src/parquet/compression.rs @@ -95,13 +95,19 @@ pub fn compress( )), #[cfg(feature = "zstd")] CompressionOptions::Zstd(level) => { - use std::io::Write; let level = level.map(|v| v.compression_level()).unwrap_or_default(); - - let mut encoder = zstd::Encoder::new(output_buf, level)?; - encoder.write_all(input_buf)?; - match encoder.finish() { - Ok(_) => Ok(()), + // Make sure the buffer is large enough; the interface assumption is + // that decompressed data is appended to the output buffer. + let old_len = output_buf.len(); + output_buf.resize( + old_len + zstd::zstd_safe::compress_bound(input_buf.len()), + 0, + ); + match zstd::bulk::compress_to_buffer(input_buf, &mut output_buf[old_len..], level) { + Ok(written_size) => { + output_buf.truncate(old_len + written_size); + Ok(()) + }, Err(e) => Err(e.into()), } }, @@ -334,13 +340,6 @@ mod tests { ))); } - #[test] - fn test_codec_gzip_high_compression() { - test_codec(CompressionOptions::Gzip(Some( - GzipLevel::try_new(10).unwrap(), - ))); - } - #[test] fn test_codec_brotli_default() { test_codec(CompressionOptions::Brotli(None)); diff --git a/crates/polars-parquet/src/parquet/deserialize/utils.rs b/crates/polars-parquet/src/parquet/deserialize/utils.rs index 0c89d09d4648..eaef2f9b6fd1 100644 --- a/crates/polars-parquet/src/parquet/deserialize/utils.rs +++ b/crates/polars-parquet/src/parquet/deserialize/utils.rs @@ -146,8 +146,6 @@ impl> Iterator for SliceFilteredIter { #[cfg(test)] mod test { - use std::collections::VecDeque; - use super::*; #[test] diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs index 904ff796dd34..cc4e62ebdd33 100644 --- a/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use super::{Packed, Unpackable, Unpacked}; /// Encodes (packs) a slice of [`Unpackable`] into bitpacked bytes `packed`, using `num_bits` per value. diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/decoder.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/decoder.rs index da707d0cf933..d000b918efb7 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/decoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/decoder.rs @@ -1,3 +1,5 @@ +use polars_utils::slice::GetSaferUnchecked; + use super::super::{ceil8, uleb128}; use super::HybridEncoded; use crate::parquet::error::Error; @@ -39,7 +41,7 @@ impl<'a> Iterator for Decoder<'a> { Ok((indicator, consumed)) => (indicator, consumed), Err(e) => return Some(Err(e)), }; - self.values = &self.values[consumed..]; + self.values = unsafe { self.values.get_unchecked_release(consumed..) }; if self.values.is_empty() { return None; }; diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs index e2e381bc4b90..1c4dd67ccec7 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs @@ -40,7 +40,8 @@ fn bitpacked_encode_u32>( let remainder = length - chunks * U32_BLOCK_LEN; let mut buffer = [0u32; U32_BLOCK_LEN]; - let compressed_chunk_size = ceil8(U32_BLOCK_LEN * num_bits); + // simplified from ceil8(U32_BLOCK_LEN * num_bits) since U32_BLOCK_LEN = 32 + let compressed_chunk_size = 4 * num_bits; for _ in 0..chunks { iterator @@ -58,6 +59,9 @@ fn bitpacked_encode_u32>( // Must be careful here to ensure we write a multiple of `num_bits` // (the bit width) to align with the spec. Some readers also rely on // this - see https://github.com/pola-rs/polars/pull/13883. + + // this is ceil8(remainder * num_bits), but we ensure the output is a + // multiple of num_bits by rewriting it as ceil8(remainder) * num_bits let compressed_remainder_size = ceil8(remainder) * num_bits; iterator .by_ref() diff --git a/crates/polars-parquet/src/parquet/encoding/mod.rs b/crates/polars-parquet/src/parquet/encoding/mod.rs index 79b608ab63b7..81d751a3e004 100644 --- a/crates/polars-parquet/src/parquet/encoding/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/mod.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - pub mod bitpacked; pub mod delta_bitpacked; pub mod delta_byte_array; diff --git a/crates/polars-parquet/src/parquet/parquet_bridge.rs b/crates/polars-parquet/src/parquet/parquet_bridge.rs index eec75e4994ca..e3851d211be8 100644 --- a/crates/polars-parquet/src/parquet/parquet_bridge.rs +++ b/crates/polars-parquet/src/parquet/parquet_bridge.rs @@ -1,5 +1,4 @@ // Bridges structs from thrift-generated code to rust enums. -use std::convert::TryFrom; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-parquet/src/parquet/read/compression.rs b/crates/polars-parquet/src/parquet/read/compression.rs index fbe2ef938f82..3366a5c56c66 100644 --- a/crates/polars-parquet/src/parquet/read/compression.rs +++ b/crates/polars-parquet/src/parquet/read/compression.rs @@ -1,5 +1,4 @@ use parquet_format_safe::DataPageHeaderV2; -use streaming_decompression; use super::page::PageIterator; use crate::parquet::compression::{self, Compression}; @@ -246,7 +245,7 @@ impl streaming_decompression::Decompressed for Page { /// A [`FallibleStreamingIterator`] that decompresses [`CompressedPage`] into [`DataPage`]. /// # Implementation /// This decompressor uses an internal [`Vec`] to perform decompressions which -/// is re-used across pages, so that a single allocation is required. +/// is reused across pages, so that a single allocation is required. /// If the pages are not compressed, the internal buffer is not used. pub struct BasicDecompressor>> { iter: _Decompressor, diff --git a/crates/polars-parquet/src/parquet/read/indexes/read.rs b/crates/polars-parquet/src/parquet/read/indexes/read.rs index 379fb4150766..9572ccf17723 100644 --- a/crates/polars-parquet/src/parquet/read/indexes/read.rs +++ b/crates/polars-parquet/src/parquet/read/indexes/read.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::io::{Cursor, Read, Seek, SeekFrom}; use parquet_format_safe::thrift::protocol::TCompactInputProtocol; diff --git a/crates/polars-parquet/src/parquet/read/metadata.rs b/crates/polars-parquet/src/parquet/read/metadata.rs index a75b939a513c..10864e194aeb 100644 --- a/crates/polars-parquet/src/parquet/read/metadata.rs +++ b/crates/polars-parquet/src/parquet/read/metadata.rs @@ -1,5 +1,4 @@ use std::cmp::min; -use std::convert::TryInto; use std::io::{Read, Seek, SeekFrom}; use parquet_format_safe::thrift::protocol::TCompactInputProtocol; diff --git a/crates/polars-parquet/src/parquet/read/page/indexed_reader.rs b/crates/polars-parquet/src/parquet/read/page/indexed_reader.rs index e72bc5de82e1..8a37566e1456 100644 --- a/crates/polars-parquet/src/parquet/read/page/indexed_reader.rs +++ b/crates/polars-parquet/src/parquet/read/page/indexed_reader.rs @@ -30,7 +30,7 @@ pub struct IndexedPageReader { // buffer to read the whole page [header][data] into memory buffer: Vec, - // buffer to store the data [data] and re-use across pages + // buffer to store the data [data] and reuse across pages data_buffer: Vec, pages: VecDeque, diff --git a/crates/polars-parquet/src/parquet/read/page/reader.rs b/crates/polars-parquet/src/parquet/read/page/reader.rs index e0078f97c6d4..0f1c7d0fb0f3 100644 --- a/crates/polars-parquet/src/parquet/read/page/reader.rs +++ b/crates/polars-parquet/src/parquet/read/page/reader.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::io::Read; use std::sync::Arc; diff --git a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs index 10f3f0614dce..a27a0b9a57a8 100644 --- a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs +++ b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs @@ -619,10 +619,10 @@ impl<'a> Parser<'a> { #[cfg(test)] mod tests { - use types::{IntegerType, PrimitiveLogicalType}; + use types::IntegerType; use super::*; - use crate::parquet::schema::types::{GroupConvertedType, PhysicalType, PrimitiveConvertedType}; + use crate::parquet::schema::types::PhysicalType; #[test] fn test_tokenize_empty_string() { diff --git a/crates/polars-parquet/src/parquet/types.rs b/crates/polars-parquet/src/parquet/types.rs index f2e7b1472eb3..b9d93a91bd26 100644 --- a/crates/polars-parquet/src/parquet/types.rs +++ b/crates/polars-parquet/src/parquet/types.rs @@ -1,5 +1,3 @@ -use std::convert::TryFrom; - use crate::parquet::schema::types::PhysicalType; /// A physical native representation of a Parquet fixed-sized type. diff --git a/crates/polars-parquet/src/parquet/write/page.rs b/crates/polars-parquet/src/parquet/write/page.rs index 1f024b629f07..ad6bc32efc68 100644 --- a/crates/polars-parquet/src/parquet/write/page.rs +++ b/crates/polars-parquet/src/parquet/write/page.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::io::Write; use std::sync::Arc; diff --git a/crates/polars-parquet/tests/it/main.rs b/crates/polars-parquet/tests/it/main.rs deleted file mode 100644 index e108117e793a..000000000000 --- a/crates/polars-parquet/tests/it/main.rs +++ /dev/null @@ -1 +0,0 @@ -mod roundtrip; diff --git a/crates/polars-pipe/Cargo.toml b/crates/polars-pipe/Cargo.toml index fc17fdd21a7f..7c9615ba9d09 100644 --- a/crates/polars-pipe/Cargo.toml +++ b/crates/polars-pipe/Cargo.toml @@ -12,13 +12,14 @@ description = "Lazy query engine for the Polars DataFrame library" arrow = { workspace = true } futures = { workspace = true, optional = true } polars-compute = { workspace = true } -polars-core = { workspace = true, features = ["lazy", "zip_with", "random", "rows", "chunked_ids"] } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random", "rows"] } polars-io = { workspace = true, features = ["ipc"] } -polars-ops = { workspace = true, features = ["search_sorted"] } +polars-ops = { workspace = true, features = ["search_sorted", "chunked_ids"] } polars-plan = { workspace = true } polars-row = { workspace = true } polars-utils = { workspace = true, features = ["sysinfo"] } tokio = { workspace = true, optional = true } +uuid = { workspace = true } crossbeam-channel = { workspace = true } crossbeam-queue = { workspace = true } @@ -37,7 +38,7 @@ cloud = ["async", "polars-io/cloud", "polars-plan/cloud", "tokio", "futures"] parquet = ["polars-plan/parquet", "polars-io/parquet", "polars-io/async"] ipc = ["polars-plan/ipc", "polars-io/ipc"] json = ["polars-plan/json", "polars-io/json"] -async = ["polars-plan/async", "polars-io/async"] +async = ["polars-plan/async", "polars-io/async", "futures"] nightly = ["polars-core/nightly", "polars-utils/nightly", "hashbrown/nightly"] cross_join = ["polars-ops/cross_join"] dtype-u8 = ["polars-core/dtype-u8"] @@ -48,4 +49,3 @@ dtype-decimal = ["polars-core/dtype-decimal"] dtype-array = ["polars-core/dtype-array"] dtype-categorical = ["polars-core/dtype-categorical"] trigger_ooc = [] -test = ["polars-core/chunked_ids"] diff --git a/crates/polars-pipe/src/executors/operators/pass.rs b/crates/polars-pipe/src/executors/operators/pass.rs index c80e0fbe6104..6b2189dbc2b8 100644 --- a/crates/polars-pipe/src/executors/operators/pass.rs +++ b/crates/polars-pipe/src/executors/operators/pass.rs @@ -27,6 +27,6 @@ impl Operator for Pass { } fn fmt(&self) -> &str { - "pass" + self.name } } diff --git a/crates/polars-pipe/src/executors/operators/projection.rs b/crates/polars-pipe/src/executors/operators/projection.rs index 1501da49d5fa..f1271457417c 100644 --- a/crates/polars-pipe/src/executors/operators/projection.rs +++ b/crates/polars-pipe/src/executors/operators/projection.rs @@ -98,7 +98,7 @@ impl Operator for ProjectionOperator { } } - let chunk = chunk.with_data(DataFrame::new_no_checks(projected)); + let chunk = chunk.with_data(unsafe { DataFrame::new_no_checks(projected) }); Ok(OperatorResult::Finished(chunk)) } fn split(&self, _thread_no: usize) -> Box { @@ -149,7 +149,8 @@ impl Operator for HstackOperator { .map(|e| e.evaluate(chunk, context.execution_state.as_any())) .collect::>>()?; - let mut df = DataFrame::new_no_checks(chunk.data.get_columns()[..width].to_vec()); + let columns = chunk.data.get_columns()[..width].to_vec(); + let mut df = unsafe { DataFrame::new_no_checks(columns) }; let schema = &*self.input_schema; if self.unchecked { diff --git a/crates/polars-pipe/src/executors/operators/reproject.rs b/crates/polars-pipe/src/executors/operators/reproject.rs index 338375969636..c48f19f84b2d 100644 --- a/crates/polars-pipe/src/executors/operators/reproject.rs +++ b/crates/polars-pipe/src/executors/operators/reproject.rs @@ -45,7 +45,7 @@ pub(crate) fn reproject_chunk( } else { let columns = chunk.data.get_columns(); let cols = positions.iter().map(|i| columns[*i].clone()).collect(); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } }; *chunk = chunk.with_data(out); Ok(()) diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs index 8e07ca2aa4d2..82afe6a0b40c 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs @@ -11,7 +11,6 @@ use polars_core::utils::arrow::compute::aggregate::sum_primitive; use polars_utils::unwrap::UnwrapUncheckedRelease; use super::*; -use crate::operators::{ArrowDataType, IdxSize}; pub struct MeanAgg { sum: Option, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs index 8466031b6114..341bb067635b 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs @@ -2,14 +2,12 @@ use std::any::Any; use arrow::array::PrimitiveArray; use polars_compute::min_max::MinMaxKernel; -use polars_core::datatypes::{AnyValue, DataType}; use polars_core::export::num::NumCast; use polars_core::prelude::*; use polars_utils::min_max::MinMax; use polars_utils::unwrap::UnwrapUncheckedRelease; use super::*; -use crate::operators::{ArrowDataType, IdxSize}; pub(super) fn new_min() -> MinMaxAgg K> { MinMaxAgg::new(MinMax::min_ignore_nan, true) diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs index 8f2fdf963638..b256ca41720f 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs @@ -10,7 +10,6 @@ use polars_core::utils::arrow::compute::aggregate::sum_primitive; use polars_utils::unwrap::UnwrapUncheckedRelease; use super::*; -use crate::operators::{ArrowDataType, IdxSize}; pub struct SumAgg { sum: Option, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs index cf00212422cc..3fa0b384dd0c 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs @@ -1,6 +1,5 @@ use std::cell::UnsafeCell; -use arrow::array::{ArrayRef, BinaryArray}; use polars_core::export::ahash::RandomState; use polars_row::{RowsEncoded, SortField}; diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs index a65e31c8b30d..4488a6faad82 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs @@ -1,6 +1,5 @@ use std::collections::LinkedList; use std::sync::atomic::{AtomicU16, Ordering}; -use std::sync::Mutex; use polars_core::utils::accumulate_dataframes_vertical_unchecked; use polars_core::POOL; @@ -146,7 +145,7 @@ impl GlobalTable { let hash = *hashes.get_unchecked(i); let chunk_index = *chunk_indexes.get_unchecked(i); - // safety: keys_iters and cols_iters are not depleted + // SAFETY: keys_iters and cols_iters are not depleted let overflow = hash_map.insert(hash, row, &mut agg_cols_iters, chunk_index); // should never overflow debug_assert!(!overflow); diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs index db73a07aa990..a8c9683d1bd1 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs @@ -178,11 +178,11 @@ impl AggHashTable { let spill_size = self.spill_size; self.spill_size = usize::MAX; for (key_other, agg_idx_other) in other.inner_map.iter() { - // safety: idx is from the hashmap, so is in bounds + // SAFETY: idx is from the hashmap, so is in bounds let row = unsafe { other.get_keys_row(key_other) }; if on_condition(key_other.hash) { - // safety: will not overflow as we set it to usize::MAX; + // SAFETY: will not overflow as we set it to usize::MAX; let agg_idx_self = unsafe { self.insert_key(key_other.hash, row) .unwrap_unchecked_release() @@ -251,7 +251,7 @@ impl AggHashTable { unsafe { let running_agg = running_aggregations.get_unchecked_release_mut(i); let av = running_agg.finalize(); - // safety: finalize creates owned anyvalues + // SAFETY: finalize creates owned AnyValues buffer.add_unchecked_owned_physical(&av); } } @@ -275,7 +275,7 @@ impl AggHashTable { ); cols.extend(agg_builders.into_iter().map(|buf| buf.into_series())); physical_agg_to_logical(&mut cols, &self.output_schema); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } } diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs index 1ceb2e22aa57..41967ee85426 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs @@ -10,7 +10,7 @@ use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Mutex; -use arrow::array::{ArrayRef, BinaryArray}; +use arrow::array::BinaryArray; use eval::Eval; use hash_table::AggHashTable; use hashbrown::hash_map::{RawEntryMut, RawVacantEntryMut}; @@ -63,7 +63,7 @@ impl SpillPayload { let mut schema = Schema::with_capacity(self.aggs.len() + 2); schema.with_column(HASH_COL.into(), DataType::UInt64); schema.with_column(INDEX_COL.into(), IDX_DTYPE); - schema.with_column(KEYS_COL.into(), DataType::Binary); + schema.with_column(KEYS_COL.into(), DataType::BinaryOffset); for s in &self.aggs { schema.with_column(s.name().into(), s.dtype().clone()); } @@ -76,14 +76,14 @@ impl SpillPayload { let hashes = UInt64Chunked::from_vec(HASH_COL, self.hashes).into_series(); let chunk_idx = IdxCa::from_vec(INDEX_COL, self.chunk_idx).into_series(); - let keys = Series::try_from((KEYS_COL, Box::new(self.keys) as ArrayRef)).unwrap(); + let keys = BinaryOffsetChunked::with_chunk(KEYS_COL, self.keys).into_series(); let mut cols = Vec::with_capacity(self.aggs.len() + 3); cols.push(hashes); cols.push(chunk_idx); cols.push(keys); cols.extend(self.aggs); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } fn spilled_to_columns( diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs index 1b8610c54251..77a939c64290 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs @@ -1,7 +1,6 @@ use polars_core::config::verbose; use super::*; -use crate::executors::sinks::io::IOThread; use crate::executors::sinks::memory::MemTracker; use crate::pipeline::{morsels_per_sink, FORCE_OOC}; diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs index e6588897290b..dd3231d5af7d 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs @@ -64,17 +64,17 @@ impl Sink for GenericGroupby2 { } // load data and hashes unsafe { - // safety: we don't hold mutable refs + // SAFETY: we don't hold mutable refs self.eval.evaluate_keys_aggs_and_hashes(context, &chunk)?; } - // safety: eval is alive for the duration of keys + // SAFETY: eval is alive for the duration of keys let keys = unsafe { self.eval.get_keys_iter() }; - // safety: we don't hold mutable refs + // SAFETY: we don't hold mutable refs let mut aggs = unsafe { self.eval.get_aggs_iters() }; let chunk_idx = chunk.chunk_index; unsafe { - // safety: the mutable borrows are not aliasing + // SAFETY: the mutable borrows are not aliasing let table = &mut *self.thread_local_table.get(); for (hash, row) in self.eval.hashes().iter().zip(keys.values_iter()) { @@ -89,7 +89,7 @@ impl Sink for GenericGroupby2 { // clear memory unsafe { drop(aggs); - // safety: we don't hold mutable refs, we just dropped them + // SAFETY: we don't hold mutable refs, we just dropped them self.eval.clear() }; @@ -122,7 +122,7 @@ impl Sink for GenericGroupby2 { } fn split(&self, _thread_no: usize) -> Box { - // safety: no mutable refs at this point + // SAFETY: no mutable refs at this point let map = unsafe { (*self.thread_local_table.get()).split() }; Box::new(Self { eval: self.eval.split(), diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs index bdb52235b3b7..4ebaf073525b 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs @@ -4,7 +4,7 @@ use polars_io::SerReader; use super::*; use crate::executors::sinks::group_by::generic::global::GlobalTable; -use crate::executors::sinks::io::{block_thread_until_io_thread_done, IOThread}; +use crate::executors::sinks::io::block_thread_until_io_thread_done; use crate::operators::{Source, SourceResult}; use crate::pipeline::PARTITION_SIZE; diff --git a/crates/polars-pipe/src/executors/sinks/group_by/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/mod.rs index 867f0a81353c..c2eaafe39d76 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/mod.rs @@ -23,7 +23,7 @@ pub(super) fn physical_agg_to_logical(cols: &mut [Series], output_schema: &Schem dt @ (DataType::Categorical(rev_map, ordering) | DataType::Enum(rev_map, ordering)) => { if let Some(rev_map) = rev_map { let cats = s.u32().unwrap().clone(); - // safety: + // SAFETY: // the rev-map comes from these categoricals unsafe { *s = CategoricalChunked::from_cats_and_rev_map_unchecked( @@ -37,7 +37,7 @@ pub(super) fn physical_agg_to_logical(cols: &mut [Series], output_schema: &Schem } else { let cats = s.u32().unwrap().clone(); if using_string_cache() { - // Safety, we go from logical to primitive back to logical so the categoricals should still match the global map. + // SAFETY, we go from logical to primitive back to logical so the categoricals should still match the global map. *s = unsafe { CategoricalChunked::from_global_indices_unchecked(cats, *ordering) .into_series() diff --git a/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs index 9f78ec517289..30fb437bd6bd 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs @@ -168,7 +168,7 @@ where if agg_map.is_empty() { return None; } - // safety: + // SAFETY: // we will not alias. let ptr = aggregators as *mut AggregateFunction; let agg_fns = @@ -209,7 +209,7 @@ where cols.push(key_builder.finish().into_series()); cols.extend(buffers.into_iter().map(|buf| buf.into_series())); physical_agg_to_logical(&mut cols, &self.output_schema); - Some(DataFrame::new_no_checks(cols)) + Some(unsafe { DataFrame::new_no_checks(cols) }) }) .collect::>(); Ok(dfs) @@ -327,7 +327,7 @@ where processed += 1; } else { // set this row to true: e.g. processed ooc - // safety: we correctly set the length with `reset_ooc_filter_rows` + // SAFETY: we correctly set the length with `reset_ooc_filter_rows` unsafe { self.ooc_state.set_row_as_ooc(iteration_idx); } diff --git a/crates/polars-pipe/src/executors/sinks/group_by/string.rs b/crates/polars-pipe/src/executors/sinks/group_by/string.rs index 289ca64dcb63..9d8bbf6e5547 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/string.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/string.rs @@ -166,7 +166,7 @@ impl StringGroupbySink { .zip(slices.par_iter()) .filter_map(|(agg_map, slice)| { let ptr = aggregators as *mut AggregateFunction; - // safety: + // SAFETY: // we will not alias. let aggregators = unsafe { std::slice::from_raw_parts_mut(ptr, aggregators_len) }; @@ -213,7 +213,7 @@ impl StringGroupbySink { cols.push(key_builder.finish().into_series()); cols.extend(buffers.into_iter().map(|buf| buf.into_series())); physical_agg_to_logical(&mut cols, &self.output_schema); - Some(DataFrame::new_no_checks(cols)) + Some(unsafe { DataFrame::new_no_checks(cols) }) }) .collect::>(); @@ -284,7 +284,7 @@ impl StringGroupbySink { match entry { RawEntryMut::Vacant(_) => { // set this row to true: e.g. processed ooc - // safety: we correctly set the length with `reset_ooc_filter_rows` + // SAFETY: we correctly set the length with `reset_ooc_filter_rows` unsafe { self.ooc_state.set_row_as_ooc(iteration_idx); } @@ -426,7 +426,7 @@ impl Sink for StringGroupbySink { // the offset in the keys of self let idx_self = k_self.idx as usize; // slice to the keys of self - // safety: + // SAFETY: // in bounds let key_self = unsafe { self.keys.get_unchecked_release(idx_self) }; // compare the keys diff --git a/crates/polars-pipe/src/executors/sinks/io.rs b/crates/polars-pipe/src/executors/sinks/io.rs index 096326d10c0a..6d135c30e8d4 100644 --- a/crates/polars-pipe/src/executors/sinks/io.rs +++ b/crates/polars-pipe/src/executors/sinks/io.rs @@ -2,7 +2,6 @@ use std::fs; use std::fs::File; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; use std::time::{Duration, SystemTime}; use crossbeam_channel::{bounded, Sender}; @@ -37,13 +36,10 @@ fn get_lockfile_path(dir: &Path) -> PathBuf { } fn get_spill_dir(operation_name: &'static str) -> PolarsResult { - let uuid = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_nanos(); + let id = uuid::Uuid::new_v4(); let mut dir = std::path::PathBuf::from(get_base_temp_dir()); - dir.push(&format!("polars/{operation_name}/{uuid}")); + dir.push(&format!("polars/{operation_name}/{id}")); if !dir.exists() { fs::create_dir_all(&dir).map_err(|err| { @@ -61,6 +57,19 @@ fn get_spill_dir(operation_name: &'static str) -> PolarsResult { Ok(dir) } +fn clean_after_delay(time: Option, secs: u64, path: &Path) { + if let Some(time) = time { + let modified_since = SystemTime::now().duration_since(time).unwrap().as_secs(); + if modified_since > secs { + // This can be fallible if another thread removes this. + // That is fine. + let _ = std::fs::remove_dir_all(path); + } + } else { + polars_warn!("could not modified time on this platform") + } +} + /// Starts a new thread that will clean up operations of directories that don't /// have a lockfile (opened with 'w' permissions). fn gc_thread(operation_name: &'static str) { @@ -81,20 +90,21 @@ fn gc_thread(operation_name: &'static str) { if let Ok(lockfile) = File::open(lockfile_path) { // lockfile can be read - if let Ok(time) = lockfile.metadata().unwrap().modified() { - let modified_since = - SystemTime::now().duration_since(time).unwrap().as_secs(); - // the lockfile can still exist if a process was canceled + if let Ok(md) = lockfile.metadata() { + let time = md.modified().ok(); + // The lockfile can still exist if a process was canceled // so we also check the modified date - // we don't expect queries that run a month - if modified_since > (SECONDS_IN_DAY as u64 * 30) { - std::fs::remove_dir_all(path).unwrap() - } - } else { - eprintln!("could not modified time on this platform") + // we don't expect queries that run a month. + clean_after_delay(time, SECONDS_IN_DAY as u64 * 30, &path); } } else { - std::fs::remove_dir_all(path).unwrap() + // If path already removed, we simply continue. + if let Ok(md) = path.metadata() { + let time = md.modified().ok(); + // Wait 15 seconds to ensure we don't remove before lockfile is created + // in a `collect_all` contention case + clean_after_delay(time, 15, &path); + } } } } @@ -262,10 +272,11 @@ struct LockFile { impl LockFile { fn new(path: PathBuf) -> PolarsResult { - if File::create(&path).is_ok() { - Ok(Self { path }) - } else { - polars_bail!(ComputeError: "could not create lockfile") + match File::create(&path) { + Ok(_) => Ok(Self { path }), + Err(e) => { + polars_bail!(ComputeError: "could not create lockfile: {e}") + }, } } } diff --git a/crates/polars-pipe/src/executors/sinks/joins/cross.rs b/crates/polars-pipe/src/executors/sinks/joins/cross.rs index 08a29ebbec6c..e29d1c3e471c 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/cross.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/cross.rs @@ -7,6 +7,7 @@ use std::vec; use polars_core::error::PolarsResult; use polars_core::frame::DataFrame; use polars_ops::prelude::CrossJoin as CrossJoinTrait; +use polars_utils::arena::Node; use smartstring::alias::String as SmartString; use crate::operators::{ @@ -19,19 +20,28 @@ pub struct CrossJoin { chunks: Vec, suffix: SmartString, swapped: bool, + node: Node, } impl CrossJoin { - pub(crate) fn new(suffix: SmartString, swapped: bool) -> Self { + pub(crate) fn new(suffix: SmartString, swapped: bool, node: Node) -> Self { CrossJoin { chunks: vec![], suffix, swapped, + node, } } } impl Sink for CrossJoin { + fn node(&self) -> Node { + self.node + } + fn is_join_build(&self) -> bool { + true + } + fn sink(&mut self, _context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { self.chunks.push(chunk); Ok(SinkResult::CanHaveMoreInput) diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs index c4972b05f659..f60dc306c0ac 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs @@ -1,19 +1,20 @@ use std::any::Any; use std::hash::{Hash, Hasher}; -use std::sync::Arc; -use arrow::array::{ArrayRef, BinaryArray}; +use arrow::array::BinaryArray; use hashbrown::hash_map::RawEntryMut; -use polars_core::datatypes::ChunkId; -use polars_core::error::PolarsResult; use polars_core::export::ahash::RandomState; use polars_core::prelude::*; use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked}; +use polars_utils::arena::Node; use polars_utils::hashing::hash_to_partition; +use polars_utils::idx_vec::UnitVec; +use polars_utils::index::ChunkId; use polars_utils::slice::GetSaferUnchecked; +use polars_utils::unitvec; use super::*; -use crate::executors::sinks::joins::inner_left::GenericJoinProbe; +use crate::executors::sinks::joins::generic_probe_inner_left::GenericJoinProbe; use crate::executors::sinks::utils::{hash_rows, load_vec}; use crate::executors::sinks::HASHMAP_INIT_SIZE; use crate::expressions::PhysicalPipedExpr; @@ -62,7 +63,7 @@ pub struct GenericBuild { hb: RandomState, // partitioned tables that will be used for probing // stores the key and the chunk_idx, df_idx of the left table - hash_tables: Vec>>, + hash_tables: Vec>>, // the columns that will be joined on join_columns_left: Arc>>, @@ -75,6 +76,7 @@ pub struct GenericBuild { // the join order is swapped to ensure we hash the smaller table swapped: bool, join_nulls: bool, + node: Node, } impl GenericBuild { @@ -85,6 +87,7 @@ impl GenericBuild { join_columns_left: Arc>>, join_columns_right: Arc>>, join_nulls: bool, + node: Node, ) -> Self { let hb: RandomState = Default::default(); let partitions = _set_partition_size(); @@ -102,6 +105,7 @@ impl GenericBuild { hash_tables, hashes: vec![], join_nulls, + node, } } } @@ -164,6 +168,13 @@ impl GenericBuild { } impl Sink for GenericBuild { + fn node(&self) -> Node { + self.node + } + fn is_join_build(&self) -> bool { + true + } + fn sink(&mut self, context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { // we do some juggling here so that we don't // end up with empty chunks @@ -196,11 +207,11 @@ impl Sink for GenericBuild { compare_fn(key, *h, &self.materialized_join_cols, row) }); - let payload = [current_chunk_offset, current_df_idx]; + let payload = ChunkId::store(current_chunk_offset, current_df_idx); match entry { RawEntryMut::Vacant(entry) => { let key = Key::new(*h, current_chunk_offset, current_df_idx); - entry.insert(key, vec![payload]); + entry.insert(key, unitvec![payload]); }, RawEntryMut::Occupied(mut entry) => { entry.get_mut().push(payload); @@ -253,22 +264,25 @@ impl Sink for GenericBuild { match entry { RawEntryMut::Vacant(entry) => { - let [chunk_idx, df_idx] = unsafe { val.get_unchecked_release(0) }; + let chunk_id = unsafe { val.get_unchecked_release(0) }; + let (chunk_idx, df_idx) = chunk_id.extract(); let new_chunk_idx = chunk_idx + chunks_offset; - let key = Key::new(h, new_chunk_idx, *df_idx); - let mut payload = vec![[new_chunk_idx, *df_idx]]; + let key = Key::new(h, new_chunk_idx, df_idx); + let mut payload = unitvec![ChunkId::store(new_chunk_idx, df_idx)]; if val.len() > 1 { - let iter = val[1..].iter().map(|[chunk_idx, val_idx]| { - [*chunk_idx + chunks_offset, *val_idx] + let iter = val[1..].iter().map(|chunk_id| { + let (chunk_idx, val_idx) = chunk_id.extract(); + ChunkId::store(chunk_idx + chunks_offset, val_idx) }); payload.extend(iter); } entry.insert(key, payload); }, RawEntryMut::Occupied(mut entry) => { - let iter = val - .iter() - .map(|[chunk_idx, val_idx]| [*chunk_idx + chunks_offset, *val_idx]); + let iter = val.iter().map(|chunk_id| { + let (chunk_idx, val_idx) = chunk_id.extract(); + ChunkId::store(chunk_idx + chunks_offset, val_idx) + }); entry.get_mut().extend(iter); }, } @@ -284,6 +298,7 @@ impl Sink for GenericBuild { self.join_columns_left.clone(), self.join_columns_right.clone(), self.join_nulls, + self.node, ); new.hb = self.hb.clone(); Box::new(new) diff --git a/crates/polars-pipe/src/executors/sinks/joins/inner_left.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs similarity index 97% rename from crates/polars-pipe/src/executors/sinks/joins/inner_left.rs rename to crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs index ce4f56c95967..14869a0abffd 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/inner_left.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs @@ -1,17 +1,17 @@ use std::borrow::Cow; -use std::sync::Arc; -use arrow::array::{Array, ArrayRef, BinaryArray}; +use arrow::array::{Array, BinaryArray}; use arrow::compute::utils::combine_validities_and; -use polars_core::datatypes::ChunkId; -use polars_core::error::PolarsResult; use polars_core::export::ahash::RandomState; use polars_core::prelude::*; use polars_core::series::IsSorted; +use polars_ops::chunked_array::DfTake; use polars_ops::frame::join::_finish_join; use polars_ops::prelude::JoinType; use polars_row::RowsEncoded; use polars_utils::hashing::hash_to_partition; +use polars_utils::idx_vec::UnitVec; +use polars_utils::index::ChunkId; use polars_utils::nulls::IsNull; use polars_utils::slice::GetSaferUnchecked; use smartstring::alias::String as SmartString; @@ -38,7 +38,7 @@ pub struct GenericJoinProbe { hb: RandomState, // partitioned tables that will be used for probing // stores the key and the chunk_idx, df_idx of the left table - hash_tables: Arc>>>, + hash_tables: Arc>>>, // the columns that will be joined on join_columns_right: Arc>>, @@ -88,7 +88,7 @@ impl GenericJoinProbe { materialized_join_cols: Arc>>, suffix: Arc, hb: RandomState, - hash_tables: Arc>>>, + hash_tables: Arc>>>, join_columns_left: Arc>>, join_columns_right: Arc>>, swapped_or_left: bool, @@ -171,7 +171,7 @@ impl GenericJoinProbe { } polars_row::convert_columns_amortized_no_order(&self.join_columns, &mut self.current_rows); - // safety: we keep rows-encode alive + // SAFETY: we keep rows-encode alive let array = unsafe { self.current_rows.borrow_array() }; Ok(if self.join_nulls { array @@ -197,7 +197,7 @@ impl GenericJoinProbe { out }, Some(names) => unsafe { - // safety: + // SAFETY: // if we have duplicate names, we overwrite // them in the next snippet left_df diff --git a/crates/polars-pipe/src/executors/sinks/joins/mod.rs b/crates/polars-pipe/src/executors/sinks/joins/mod.rs index f906f5f1d190..5da9cfd715c2 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/mod.rs @@ -1,7 +1,7 @@ #[cfg(feature = "cross_join")] mod cross; mod generic_build; -mod inner_left; +mod generic_probe_inner_left; #[cfg(feature = "cross_join")] pub(crate) use cross::*; diff --git a/crates/polars-pipe/src/executors/sinks/ordered.rs b/crates/polars-pipe/src/executors/sinks/ordered.rs index 156dad5f9a9e..083906125687 100644 --- a/crates/polars-pipe/src/executors/sinks/ordered.rs +++ b/crates/polars-pipe/src/executors/sinks/ordered.rs @@ -53,6 +53,7 @@ impl Sink for OrderedSink { ))); } self.sort(); + let chunks = std::mem::take(&mut self.chunks); Ok(FinalizedSink::Finished(chunks_to_df_unchecked(chunks))) } diff --git a/crates/polars-pipe/src/executors/sinks/output/csv.rs b/crates/polars-pipe/src/executors/sinks/output/csv.rs index 401933ba4e9c..053ff68b40fa 100644 --- a/crates/polars-pipe/src/executors/sinks/output/csv.rs +++ b/crates/polars-pipe/src/executors/sinks/output/csv.rs @@ -56,6 +56,6 @@ impl SinkWriter for polars_io::csv::BatchedWriter { } fn _finish(&mut self) -> PolarsResult<()> { - Ok(()) + self.finish() } } diff --git a/crates/polars-pipe/src/executors/sinks/output/file_sink.rs b/crates/polars-pipe/src/executors/sinks/output/file_sink.rs index a4ee0bb3d07b..937d88458ed9 100644 --- a/crates/polars-pipe/src/executors/sinks/output/file_sink.rs +++ b/crates/polars-pipe/src/executors/sinks/output/file_sink.rs @@ -4,10 +4,13 @@ use std::thread::JoinHandle; use crossbeam_channel::{Receiver, Sender}; use polars_core::prelude::*; -use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult}; +use crate::operators::{ + DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult, StreamingVstacker, +}; pub(super) trait SinkWriter { fn _write_batch(&mut self, df: &DataFrame) -> PolarsResult<()>; + fn _finish(&mut self) -> PolarsResult<()>; } @@ -24,6 +27,7 @@ pub(super) fn init_writer_thread( // keep chunks around until all chunks per sink are written // then we write them all at once. let mut chunks = Vec::with_capacity(morsels_per_sink); + let mut vstacker = StreamingVstacker::default(); while let Ok(chunk) = receiver.recv() { // `last_write` indicates if all chunks are processed, e.g. this is the last write. @@ -40,13 +44,26 @@ pub(super) fn init_writer_thread( chunks.sort_by_key(|chunk| chunk.chunk_index); } - for chunk in chunks.iter() { - writer._write_batch(&chunk.data).unwrap() + for chunk in chunks.drain(0..) { + for mut df in vstacker.add(chunk.data) { + // The dataframe may only be a single, large chunk, in + // which case we don't want to bother with copying it... + if df.n_chunks() > 1 { + df.as_single_chunk(); + } + writer._write_batch(&df).unwrap(); + } } // all chunks are written remove them chunks.clear(); if last_write { + if let Some(mut df) = vstacker.finish() { + if df.n_chunks() > 1 { + df.as_single_chunk(); + } + writer._write_batch(&df).unwrap(); + } writer._finish().unwrap(); return; } diff --git a/crates/polars-pipe/src/executors/sinks/output/ipc.rs b/crates/polars-pipe/src/executors/sinks/output/ipc.rs index f7cbab92248a..e0a479f32966 100644 --- a/crates/polars-pipe/src/executors/sinks/output/ipc.rs +++ b/crates/polars-pipe/src/executors/sinks/output/ipc.rs @@ -1,5 +1,4 @@ use std::path::Path; -use std::sync::Arc; use crossbeam_channel::bounded; use polars_core::prelude::*; diff --git a/crates/polars-pipe/src/executors/sinks/slice.rs b/crates/polars-pipe/src/executors/sinks/slice.rs index 6bab970f2dff..a5e1c0e24aca 100644 --- a/crates/polars-pipe/src/executors/sinks/slice.rs +++ b/crates/polars-pipe/src/executors/sinks/slice.rs @@ -3,6 +3,8 @@ use std::sync::atomic::Ordering; use std::sync::{Arc, Mutex}; use polars_core::error::PolarsResult; +use polars_core::frame::DataFrame; +use polars_core::schema::SchemaRef; use polars_utils::atomic::SyncCounter; use crate::operators::{ @@ -15,6 +17,7 @@ pub struct SliceSink { current_len: SyncCounter, len: usize, chunks: Arc>>, + schema: SchemaRef, } impl Clone for SliceSink { @@ -24,18 +27,20 @@ impl Clone for SliceSink { current_len: self.current_len.clone(), len: self.len, chunks: self.chunks.clone(), + schema: self.schema.clone(), } } } impl SliceSink { - pub fn new(offset: u64, len: usize) -> SliceSink { + pub fn new(offset: u64, len: usize, schema: SchemaRef) -> SliceSink { let offset = SyncCounter::new(offset as usize); SliceSink { offset, current_len: SyncCounter::new(0), len, chunks: Default::default(), + schema, } } @@ -87,7 +92,13 @@ impl Sink for SliceSink { self.sort(); let chunks = std::mem::take(&mut self.chunks); let mut chunks = chunks.lock().unwrap(); - let chunks = std::mem::take(chunks.as_mut()); + let chunks: Vec = std::mem::take(chunks.as_mut()); + if chunks.is_empty() { + return Ok(FinalizedSink::Finished(DataFrame::from( + self.schema.as_ref(), + ))); + } + let df = chunks_to_df_unchecked(chunks); let offset = self.offset.load(Ordering::Acquire) as i64; diff --git a/crates/polars-pipe/src/executors/sinks/sort/ooc.rs b/crates/polars-pipe/src/executors/sinks/sort/ooc.rs index dd77a1cf8a69..60547ec6c076 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/ooc.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/ooc.rs @@ -4,7 +4,9 @@ use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; use crossbeam_queue::SegQueue; use polars_core::prelude::*; use polars_core::series::IsSorted; -use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use polars_core::utils::{ + accumulate_dataframes_vertical_unchecked, accumulate_dataframes_vertical_unchecked_optional, +}; use polars_core::POOL; use polars_io::ipc::IpcReader; use polars_io::SerReader; @@ -54,7 +56,8 @@ impl PartitionSpillBuf { // so we pop no more than the current size. let pop_max = len; let iter = (0..pop_max).flat_map(|_| self.chunks.pop()); - Some(accumulate_dataframes_vertical_unchecked(iter)) + // Due to race conditions, the chunks can already be popped, so we use optional. + accumulate_dataframes_vertical_unchecked_optional(iter) } else { None } diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink.rs b/crates/polars-pipe/src/executors/sinks/sort/sink.rs index 671632955a03..b054282a3257 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink.rs @@ -108,7 +108,7 @@ impl SortSink { // expensive let df = accumulate_dataframes_vertical_unchecked(self.chunks.drain(..)); if df.height() > 0 { - // safety: we just asserted height > 0 + // SAFETY: we just asserted height > 0 let sample = unsafe { let s = &df.get_columns()[self.sort_idx]; s.to_physical_repr().get_unchecked(0).into_static().unwrap() @@ -173,7 +173,7 @@ impl Sink for SortSink { let lock = self.io_thread.read().unwrap(); let io_thread = lock.as_ref().unwrap(); - let dist = Series::from_any_values("", &self.dist_sample, false).unwrap(); + let dist = Series::from_any_values("", &self.dist_sample, true).unwrap(); let dist = dist.sort_with(SortOptions { descending: self.sort_args.descending[0], nulls_last: self.sort_args.nulls_last, diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index 785c8e10f28f..bf659bd1598c 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -1,6 +1,6 @@ use std::any::Any; -use arrow::array::{ArrayRef, BinaryArray}; +use arrow::array::BinaryArray; use polars_core::prelude::sort::_broadcast_descending; use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_compat_array; use polars_core::prelude::*; @@ -81,7 +81,7 @@ fn finalize_dataframe( assert_eq!(encoded.chunks().len(), 1); let arr = encoded.downcast_iter().next().unwrap(); - // safety + // SAFETY: // temporary extend lifetime // this is safe as the lifetime in rows stays bound to this scope let arrays = { @@ -230,7 +230,7 @@ impl SortSinkMultiple { }; debug_assert_eq!(column.chunks().len(), 1); - // Safety: length is correct + // SAFETY: length is correct unsafe { chunk.data.with_column_unchecked(column) }; Ok(()) } diff --git a/crates/polars-pipe/src/executors/sources/csv.rs b/crates/polars-pipe/src/executors/sources/csv.rs index d760f91ca633..468cab6d8f75 100644 --- a/crates/polars-pipe/src/executors/sources/csv.rs +++ b/crates/polars-pipe/src/executors/sources/csv.rs @@ -79,6 +79,7 @@ impl CsvSource { .with_rechunk(false) .with_chunk_size(chunk_size) .with_row_index(file_options.row_index) + .with_n_threads(options.n_threads) .with_try_parse_dates(options.try_parse_dates) .truncate_ragged_lines(options.truncate_ragged_lines) .raise_if_empty(options.raise_if_empty); diff --git a/crates/polars-pipe/src/operators/chunks.rs b/crates/polars-pipe/src/operators/chunks.rs index 4789eae6281f..55a4975970b8 100644 --- a/crates/polars-pipe/src/operators/chunks.rs +++ b/crates/polars-pipe/src/operators/chunks.rs @@ -30,3 +30,150 @@ impl DataChunk { pub(crate) fn chunks_to_df_unchecked(chunks: Vec) -> DataFrame { accumulate_dataframes_vertical_unchecked(chunks.into_iter().map(|c| c.data)) } + +/// Combine a series of `DataFrame`s, and if they're small enough, combine them +/// into larger `DataFrame`s using `vstack`. This allows the caller to turn them +/// into contiguous memory allocations so that we don't suffer from overhead of +/// many small writes. The assumption is that added `DataFrame`s are already in +/// the correct order, and can therefore be combined. +/// +/// The benefit of having a series of `DataFrame` that are e.g. 4MB each that +/// are then made contiguous is that you're not using a lot of memory (an extra +/// 4MB), but you're still doing better than if you had a series of of 2KB +/// `DataFrame`s. +/// +/// Changing the `DataFrame` into contiguous chunks is the caller's +/// responsibility. +#[cfg(feature = "parquet")] +#[derive(Clone)] +pub(crate) struct StreamingVstacker { + current_dataframe: Option, + /// How big should resulting chunks be, if possible? + output_chunk_size: usize, +} + +#[cfg(feature = "parquet")] +impl StreamingVstacker { + /// Create a new instance. + pub fn new(output_chunk_size: usize) -> Self { + Self { + current_dataframe: None, + output_chunk_size, + } + } + + /// Add another `DataFrame`, return (potentially combined) `DataFrame`s that + /// result, if any. + pub fn add(&mut self, next_frame: DataFrame) -> impl Iterator { + let mut result: [Option; 2] = [None, None]; + + // If the next chunk is too large, we probably don't want make copies of + // it if a caller does as_single_chunk(), so we flush in advance. + if self.current_dataframe.is_some() + && next_frame.estimated_size() > self.output_chunk_size / 4 + { + result[0] = self.flush(); + } + + if let Some(ref mut current_frame) = self.current_dataframe { + current_frame + .vstack_mut(&next_frame) + .expect("These are chunks from the same dataframe"); + } else { + self.current_dataframe = Some(next_frame); + }; + + if self.current_dataframe.as_ref().unwrap().estimated_size() > self.output_chunk_size { + result[1] = self.flush(); + } + result.into_iter().flatten() + } + + /// Clear and return any cached `DataFrame` data. + #[must_use] + fn flush(&mut self) -> Option { + std::mem::take(&mut self.current_dataframe) + } + + /// Finish and return any remaining cached `DataFrame` data. The only way + /// that `SemicontiguousVstacker` should be cleaned up. + #[must_use] + pub fn finish(mut self) -> Option { + self.flush() + } +} + +#[cfg(feature = "parquet")] +impl Default for StreamingVstacker { + /// 4 MB was chosen based on some empirical experiments that showed it to + /// be decently faster than lower or higher values, and it's small enough + /// it won't impact memory usage significantly. + fn default() -> Self { + StreamingVstacker::new(4 * 1024 * 1024) + } +} + +#[cfg(test)] +#[cfg(feature = "parquet")] +mod test { + use super::*; + + /// DataFrames get merged into chunks that are bigger than the specified + /// size when possible. + #[test] + fn semicontiguous_vstacker_merges() { + let test = semicontiguous_vstacker_merges_impl; + test(vec![10]); + test(vec![10, 10, 10, 10, 10, 10, 10]); + test(vec![10, 40, 10, 10, 10, 10]); + test(vec![40, 10, 10, 40, 10, 10, 40]); + test(vec![50, 50, 50]); + } + + /// Eventually would be nice to drive this with proptest. + fn semicontiguous_vstacker_merges_impl(df_lengths: Vec) { + // Convert the lengths into a series of DataFrames: + let mut vstacker = StreamingVstacker::new(4096); + let dfs: Vec = df_lengths + .iter() + .enumerate() + .map(|(i, length)| { + let series = Series::new("val", vec![i as u64; *length]); + DataFrame::new(vec![series]).unwrap() + }) + .collect(); + + // Combine the DataFrames using a SemicontiguousVstacker: + let mut results = vec![]; + for (i, df) in dfs.iter().enumerate() { + for mut result_df in vstacker.add(df.clone()) { + result_df.as_single_chunk(); + results.push((i, result_df)); + } + } + if let Some(mut result_df) = vstacker.finish() { + result_df.as_single_chunk(); + results.push((df_lengths.len() - 1, result_df)); + } + + // Make sure the lengths are as sufficiently large, and the chunks + // were merged, the whole point of the exercise: + for (original_idx, result_df) in &results { + if result_df.height() < 40 { + // This means either this was the last df, or the next one + // was big enough we decided not to aggregate. + if *original_idx < results.len() - 1 { + assert!(dfs[original_idx + 1].height() > 10); + } + } + // Make sure all result DataFrames only have a single chunk. + assert_eq!(result_df.get_columns()[0].chunk_lengths().len(), 1); + } + + // Make sure the data was preserved: + assert_eq!( + accumulate_dataframes_vertical_unchecked(dfs.into_iter()), + accumulate_dataframes_vertical_unchecked(results.into_iter().map(|(_, df)| df)), + ); + } +} diff --git a/crates/polars-pipe/src/operators/operator.rs b/crates/polars-pipe/src/operators/operator.rs index 43d621edb03f..9082728a9fdb 100644 --- a/crates/polars-pipe/src/operators/operator.rs +++ b/crates/polars-pipe/src/operators/operator.rs @@ -2,6 +2,7 @@ use super::*; pub enum OperatorResult { /// needs to be called again with new chunk. + /// Or in case of `flush` needs to be called again. NeedsNewData, /// needs to be called again with same chunk. HaveMoreOutPut(DataChunk), @@ -16,6 +17,14 @@ pub trait Operator: Send + Sync { chunk: &DataChunk, ) -> PolarsResult; + fn flush(&mut self) -> PolarsResult { + unimplemented!() + } + + fn must_flush(&self) -> bool { + false + } + fn split(&self, thread_no: usize) -> Box; fn fmt(&self) -> &str; diff --git a/crates/polars-pipe/src/operators/sink.rs b/crates/polars-pipe/src/operators/sink.rs index 19b60aef1772..6f803ceb1f5d 100644 --- a/crates/polars-pipe/src/operators/sink.rs +++ b/crates/polars-pipe/src/operators/sink.rs @@ -1,4 +1,7 @@ use std::any::Any; +use std::fmt::{Debug, Formatter}; + +use polars_utils::arena::Node; use super::*; @@ -14,6 +17,17 @@ pub enum FinalizedSink { Source(Box), } +impl Debug for FinalizedSink { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let s = match self { + FinalizedSink::Finished(_) => "finished", + FinalizedSink::Operator(_) => "operator", + FinalizedSink::Source(_) => "source", + }; + write!(f, "{s}") + } +} + pub trait Sink: Send + Sync { fn sink(&mut self, context: &PExecutionContext, chunk: DataChunk) -> PolarsResult; @@ -26,4 +40,13 @@ pub trait Sink: Send + Sync { fn as_any(&mut self) -> &mut dyn Any; fn fmt(&self) -> &str; + + fn is_join_build(&self) -> bool { + false + } + + // Only implemented for Join sinks + fn node(&self) -> Node { + unimplemented!() + } } diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index d8566f46f23f..2e0375353759 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -1,6 +1,5 @@ use std::cell::RefCell; use std::rc::Rc; -use std::sync::Arc; use hashbrown::hash_map::Entry; use polars_core::prelude::*; @@ -248,8 +247,10 @@ where match &options.args.how { #[cfg(feature = "cross_join")] - JoinType::Cross => Box::new(CrossJoin::new(options.args.suffix().into(), swapped)) - as Box, + JoinType::Cross => { + Box::new(CrossJoin::new(options.args.suffix().into(), swapped, node)) + as Box + }, join_type @ JoinType::Inner | join_type @ JoinType::Left => { let input_schema_left = lp_arena.get(*input_left).schema(lp_arena); let join_columns_left = Arc::new(exprs_to_physical( @@ -279,13 +280,15 @@ where join_columns_left, join_columns_right, options.args.join_nulls, + node, )) as Box }, _ => unimplemented!(), } }, - Slice { offset, len, .. } => { - let slice = SliceSink::new(*offset as u64, *len as usize); + Slice { input, offset, len } => { + let input_schema = lp_arena.get(*input).schema(lp_arena); + let slice = SliceSink::new(*offset as u64, *len as usize, input_schema.into_owned()); Box::new(slice) as Box }, Sort { diff --git a/crates/polars-pipe/src/pipeline/dispatcher.rs b/crates/polars-pipe/src/pipeline/dispatcher.rs index 7fc7cdb9a35b..14dfdf4ad140 100644 --- a/crates/polars-pipe/src/pipeline/dispatcher.rs +++ b/crates/polars-pipe/src/pipeline/dispatcher.rs @@ -1,5 +1,5 @@ use std::cell::RefCell; -use std::collections::VecDeque; +use std::collections::{BTreeSet, VecDeque}; use std::fmt::{Debug, Formatter}; use std::rc::Rc; use std::sync::{Arc, Mutex}; @@ -308,9 +308,12 @@ impl PipeLine { sink: &mut Box, ) -> PolarsResult { debug_assert!(!operators.is_empty()); + + // Stack based operator execution. let mut in_process = vec![]; let operator_offset = 0usize; in_process.push((operator_offset, chunk)); + let mut needs_flush = BTreeSet::new(); while let Some((op_i, chunk)) = in_process.pop() { match operators.get_mut(op_i) { @@ -321,16 +324,20 @@ impl PipeLine { }, Some(op) => { match op.execute(ec, &chunk)? { - OperatorResult::Finished(chunk) => in_process.push((op_i + 1, chunk)), + OperatorResult::Finished(chunk) => { + if op.must_flush() { + let _ = needs_flush.insert(op_i); + } + in_process.push((op_i + 1, chunk)) + }, OperatorResult::HaveMoreOutPut(output_chunk) => { - // first on the stack the next operator call + // Push the next operator call with the same chunk on the stack in_process.push((op_i, chunk)); - // but first push the output in the next operator - // is a join can produce many rows, we want the filter to - // be executed in between. - // or sink into a slice so that we get sink::finished - // before we grow the stack with ever more coming chunks + // But first push the output in the next operator + // If a join can produce many rows, we want the filter to + // be executed in between, or sink into a slice so that we get + // sink::finished before we grow the stack with ever more coming chunks in_process.push((op_i + 1, output_chunk)); }, OperatorResult::NeedsNewData => { @@ -340,6 +347,78 @@ impl PipeLine { }, } } + + // Stack based flushing + operator execution. + if !needs_flush.is_empty() { + drop(in_process); + let mut in_process = vec![]; + + for op_i in needs_flush.into_iter() { + // Push all operators that need flushing on the stack. + // The `None` indicates that we have no `chunk` input, so we `flush`. + // `Some(chunk)` is the pushing branch + in_process.push((op_i, None)); + + // Next we immediately pop and determine the order of execution below. + // This is to ensure that all operators below upper operators are completely + // flushed when the `flush` is called in higher operators. As operators can `flush` + // multiple times. + while let Some((op_i, chunk)) = in_process.pop() { + match chunk { + // The branch for flushing. + None => { + let op = operators.get_mut(op_i).unwrap(); + match op.flush()? { + OperatorResult::Finished(chunk) => { + // Push the chunk in the next operator. + in_process.push((op_i + 1, Some(chunk))) + }, + OperatorResult::HaveMoreOutPut(chunk) => { + // Ensure it is flushed again + in_process.push((op_i, None)); + // Push the chunk in the next operator. + in_process.push((op_i + 1, Some(chunk))) + }, + _ => unreachable!(), + } + }, + // The branch for pushing data in the operators. + // This is the same as the default stack exectuor, except now it pushes + // `Some(chunk)` instead of `chunk`. + Some(chunk) => { + match operators.get_mut(op_i) { + None => { + if let SinkResult::Finished = sink.sink(ec, chunk)? { + return Ok(SinkResult::Finished); + } + }, + Some(op) => { + match op.execute(ec, &chunk)? { + OperatorResult::Finished(chunk) => { + in_process.push((op_i + 1, Some(chunk))) + }, + OperatorResult::HaveMoreOutPut(output_chunk) => { + // Push the next operator call with the same chunk on the stack + in_process.push((op_i, Some(chunk))); + + // But first push the output in the next operator + // If a join can produce many rows, we want the filter to + // be executed in between, or sink into a slice so that we get + // sink::finished before we grow the stack with ever more coming chunks + in_process.push((op_i + 1, Some(output_chunk))); + }, + OperatorResult::NeedsNewData => { + // Done, take another chunk from the stack + }, + } + }, + } + }, + } + } + } + } + Ok(SinkResult::CanHaveMoreInput) } @@ -423,8 +502,38 @@ impl PipeLine { let mut pipeline = pipeline_q.borrow_mut().pop_front().unwrap(); let (count, mut sink) = pipeline.run_pipeline_no_finalize(ec, pipeline_q.clone())?; - reduced_sink.combine(sink.as_mut()); - shared_sink_count = count; + // This branch is hit when we have a Union of joins. + // The build side must be converted into an operator and replaced in the next pipeline. + + // Check either: + // 1. There can be a union source that sinks into a single join: + // scan_parquet(*) -> join B + // 2. There can be a union of joins + // C - JOIN A, B + // concat (A, B, C) + // + // So to ensure that we don't finalize we check + // - They are not both join builds + // - If they are both join builds, check they are note the same build, otherwise + // we must call the `combine` branch. + if sink.is_join_build() + && (!reduced_sink.is_join_build() || (sink.node() != reduced_sink.node())) + { + let FinalizedSink::Operator(op) = sink.finalize(ec)? else { + unreachable!() + }; + let mut q = pipeline_q.borrow_mut(); + let Some(node) = pipeline.sink_nodes.pop() else { + unreachable!() + }; + + for probe_side in q.iter_mut() { + let _ = probe_side.replace_operator(op.as_ref(), node); + } + } else { + reduced_sink.combine(sink.as_mut()); + shared_sink_count = count; + } } } diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index ebe0735f7a41..dc3c5faf610a 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -54,7 +54,7 @@ serde = [ ] streaming = [] parquet = ["polars-io/parquet", "polars-parquet"] -async = ["polars-io/async"] +async = ["polars-io/async", "futures"] cloud = ["async", "polars-io/cloud"] ipc = ["polars-io/ipc"] json = ["polars-io/json", "polars-json"] @@ -130,8 +130,9 @@ ewma = ["polars-ops/ewma"] dot_diagram = [] unique_counts = ["polars-ops/unique_counts"] log = ["polars-ops/log"] -chunked_ids = ["polars-core/chunked_ids"] +chunked_ids = [] list_to_struct = ["polars-ops/list_to_struct"] +array_to_struct = ["polars-ops/array_to_struct"] row_hash = ["polars-core/row_hash", "polars-ops/hash"] reinterpret = ["polars-core/reinterpret", "polars-ops/reinterpret"] string_pad = ["polars-ops/string_pad"] diff --git a/crates/polars-plan/src/dot.rs b/crates/polars-plan/src/dot.rs index 22e1272c4247..735813721e84 100644 --- a/crates/polars-plan/src/dot.rs +++ b/crates/polars-plan/src/dot.rs @@ -5,7 +5,6 @@ use std::path::PathBuf; use polars_core::prelude::*; use crate::prelude::*; -use crate::utils::expr_to_leaf_column_names; impl Expr { /// Get a dot language representation of the Expression. @@ -48,7 +47,7 @@ impl Expr { let current_node = format!( r#"BINARY left _; - op {op:?}, + op {op:?}; right: _ [{branch},{id}]"#, ); @@ -166,7 +165,7 @@ impl LogicalPlan { let fmt = if *count == usize::MAX { Cow::Borrowed("CACHE") } else { - Cow::Owned(format!("CACHE: {}times", *count)) + Cow::Owned(format!("CACHE: {} times", *count)) }; let current_node = DotNode { branch: *cache_id, @@ -313,7 +312,7 @@ impl LogicalPlan { } let pred = fmt_predicate(selection.as_ref()); - let fmt = format!("TABLE\nπ {n_columns}/{total_columns};\nσ {pred};"); + let fmt = format!("TABLE\nπ {n_columns}/{total_columns};\nσ {pred}"); let current_node = DotNode { branch, id, @@ -357,7 +356,7 @@ impl LogicalPlan { } => { let fmt = format!( r#"JOIN {} - left {:?}; + left: {:?}; right: {:?}"#, options.args.how, left_on, right_on ); @@ -406,7 +405,7 @@ impl LogicalPlan { input.dot(acc_str, (branch, id + 1), current_node, id_map) }, Error { err, .. } => { - let fmt = format!("{:?}", &**err); + let fmt = format!("{:?}", &err.0); let current_node = DotNode { branch, id, diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs index 285872364932..b00347ba8007 100644 --- a/crates/polars-plan/src/dsl/array.rs +++ b/crates/polars-plan/src/dsl/array.rs @@ -1,12 +1,13 @@ -use polars_core::prelude::SortOptions; +use polars_core::prelude::*; +#[cfg(feature = "array_to_struct")] +use polars_ops::chunked_array::array::{ + arr_default_struct_name_gen, ArrToStructNameGenerator, ToStruct, +}; -use crate::dsl::function_expr::{ArrayFunction, FunctionExpr}; +use crate::dsl::function_expr::ArrayFunction; use crate::prelude::*; -/// Specialized expressions for [`Series`][Series] of [`DataType::List`][DataType::List]. -/// -/// [Series]: polars_core::prelude::Series -/// [DataType::List]: polars_core::prelude::DataType::List +/// Specialized expressions for [`Series`] of [`DataType::Array`]. pub struct ArrayNameSpace(pub Expr); impl ArrayNameSpace { @@ -28,6 +29,24 @@ impl ArrayNameSpace { .map_private(FunctionExpr::ArrayExpr(ArrayFunction::Sum)) } + /// Compute the std of the items in every subarray. + pub fn std(self, ddof: u8) -> Expr { + self.0 + .map_private(FunctionExpr::ArrayExpr(ArrayFunction::Std(ddof))) + } + + /// Compute the var of the items in every subarray. + pub fn var(self, ddof: u8) -> Expr { + self.0 + .map_private(FunctionExpr::ArrayExpr(ArrayFunction::Var(ddof))) + } + + /// Compute the median of the items in every subarray. + pub fn median(self) -> Expr { + self.0 + .map_private(FunctionExpr::ArrayExpr(ArrayFunction::Median)) + } + /// Keep only the unique values in every sub-array. pub fn unique(self) -> Expr { self.0 @@ -132,4 +151,40 @@ impl ArrayNameSpace { options }) } + + #[cfg(feature = "array_to_struct")] + pub fn to_struct(self, name_generator: Option) -> Expr { + self.0 + .map( + move |s| { + s.array()? + .to_struct(name_generator.clone()) + .map(|s| Some(s.into_series())) + }, + GetOutput::map_dtype(move |dt: &DataType| { + let DataType::Array(inner, width) = dt else { + panic!("Only array dtype is expected for `arr.to_struct`.") + }; + + let fields = (0..*width) + .map(|i| { + let name = arr_default_struct_name_gen(i); + Field::from_owned(name, inner.as_ref().clone()) + }) + .collect(); + DataType::Struct(fields) + }), + ) + .with_fmt("arr.to_struct") + } + + /// Shift every sub-array. + pub fn shift(self, n: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::ArrayExpr(ArrayFunction::Shift), + &[n], + false, + false, + ) + } } diff --git a/crates/polars-plan/src/dsl/binary.rs b/crates/polars-plan/src/dsl/binary.rs index c8a6dc368391..1a395ac1cb6f 100644 --- a/crates/polars-plan/src/dsl/binary.rs +++ b/crates/polars-plan/src/dsl/binary.rs @@ -1,4 +1,3 @@ -use super::function_expr::BinaryFunction; use super::*; /// Specialized expressions for [`Series`] of [`DataType::String`]. pub struct BinaryNameSpace(pub(crate) Expr); diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index b91deca57e9d..cf7cc2a31fbc 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -6,7 +6,6 @@ use polars_core::prelude::*; use serde::{Deserialize, Serialize}; pub use super::expr_dyn_fn::*; -use crate::dsl::function_expr::FunctionExpr; use crate::prelude::*; #[derive(PartialEq, Clone)] diff --git a/crates/polars-plan/src/dsl/from.rs b/crates/polars-plan/src/dsl/from.rs index e815fdb7ffe4..eeaa631521cb 100644 --- a/crates/polars-plan/src/dsl/from.rs +++ b/crates/polars-plan/src/dsl/from.rs @@ -6,8 +6,6 @@ impl From for Expr { } } -pub trait RefString {} - impl From<&str> for Expr { fn from(s: &str) -> Self { col(s) diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs index 55fb2399bcd4..77b8ac2f68e3 100644 --- a/crates/polars-plan/src/dsl/function_expr/array.rs +++ b/crates/polars-plan/src/dsl/function_expr/array.rs @@ -11,6 +11,9 @@ pub enum ArrayFunction { Sum, ToList, Unique(bool), + Std(u8), + Var(u8), + Median, #[cfg(feature = "array_any_all")] Any, #[cfg(feature = "array_any_all")] @@ -25,6 +28,7 @@ pub enum ArrayFunction { Contains, #[cfg(feature = "array_count")] CountMatches, + Shift, } impl ArrayFunction { @@ -35,6 +39,9 @@ impl ArrayFunction { Sum => mapper.nested_sum_type(), ToList => mapper.try_map_dtype(map_array_dtype_to_list_dtype), Unique(_) => mapper.try_map_dtype(map_array_dtype_to_list_dtype), + Std(_) => mapper.map_to_float_dtype(), + Var(_) => mapper.map_to_float_dtype(), + Median => mapper.map_to_float_dtype(), #[cfg(feature = "array_any_all")] Any | All => mapper.with_dtype(DataType::Boolean), Sort(_) => mapper.with_same_dtype(), @@ -46,6 +53,7 @@ impl ArrayFunction { Contains => mapper.with_dtype(DataType::Boolean), #[cfg(feature = "array_count")] CountMatches => mapper.with_dtype(IDX_DTYPE), + Shift => mapper.with_same_dtype(), } } } @@ -67,6 +75,9 @@ impl Display for ArrayFunction { Sum => "sum", ToList => "to_list", Unique(_) => "unique", + Std(_) => "std", + Var(_) => "var", + Median => "median", #[cfg(feature = "array_any_all")] Any => "any", #[cfg(feature = "array_any_all")] @@ -81,6 +92,7 @@ impl Display for ArrayFunction { Contains => "contains", #[cfg(feature = "array_count")] CountMatches => "count_matches", + Shift => "shift", }; write!(f, "arr.{name}") } @@ -95,6 +107,9 @@ impl From for SpecialEq> { Sum => map!(sum), ToList => map!(to_list), Unique(stable) => map!(unique, stable), + Std(ddof) => map!(std, ddof), + Var(ddof) => map!(var, ddof), + Median => map!(median), #[cfg(feature = "array_any_all")] Any => map!(any), #[cfg(feature = "array_any_all")] @@ -109,6 +124,7 @@ impl From for SpecialEq> { Contains => map_as_slice!(contains), #[cfg(feature = "array_count")] CountMatches => map_as_slice!(count_matches), + Shift => map_as_slice!(shift), } } } @@ -125,6 +141,17 @@ pub(super) fn sum(s: &Series) -> PolarsResult { s.array()?.array_sum() } +pub(super) fn std(s: &Series, ddof: u8) -> PolarsResult { + s.array()?.array_std(ddof) +} + +pub(super) fn var(s: &Series, ddof: u8) -> PolarsResult { + s.array()?.array_var(ddof) +} +pub(super) fn median(s: &Series) -> PolarsResult { + s.array()?.array_median() +} + pub(super) fn unique(s: &Series, stable: bool) -> PolarsResult { let ca = s.array()?; let out = if stable { @@ -201,3 +228,10 @@ pub(super) fn count_matches(args: &[Series]) -> PolarsResult { let ca = s.array()?; ca.array_count_matches(element.get(0).unwrap()) } + +pub(super) fn shift(s: &[Series]) -> PolarsResult { + let ca = s[0].array()?; + let n = &s[1]; + + ca.array_shift(n) +} diff --git a/crates/polars-plan/src/dsl/function_expr/boolean.rs b/crates/polars-plan/src/dsl/function_expr/boolean.rs index 1ddef0b8d81a..6bb888fe7233 100644 --- a/crates/polars-plan/src/dsl/function_expr/boolean.rs +++ b/crates/polars-plan/src/dsl/function_expr/boolean.rs @@ -1,5 +1,3 @@ -use std::ops::Not; - use super::*; #[cfg(feature = "is_in")] use crate::wrap; @@ -36,12 +34,25 @@ pub enum BooleanFunction { IsIn, AllHorizontal, AnyHorizontal, + // Also bitwise negate Not, } impl BooleanFunction { pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { - mapper.with_dtype(DataType::Boolean) + match self { + BooleanFunction::Not => { + mapper.try_map_dtype(|dtype| { + match dtype { + DataType::Boolean => Ok(DataType::Boolean), + dt if dt.is_integer() => Ok(dt.clone()), + dt => polars_bail!(InvalidOperation: "dtype {:?} not supported in 'not' operation", dt) + } + }) + + }, + _ => mapper.with_dtype(DataType::Boolean), + } } } @@ -200,5 +211,5 @@ fn all_horizontal(s: &[Series]) -> PolarsResult { } fn not(s: &Series) -> PolarsResult { - Ok(s.bool()?.not().into_series()) + polars_ops::series::negate_bitwise(s) } diff --git a/crates/polars-plan/src/dsl/function_expr/clip.rs b/crates/polars-plan/src/dsl/function_expr/clip.rs index 2f643857e1a2..adae248a8af2 100644 --- a/crates/polars-plan/src/dsl/function_expr/clip.rs +++ b/crates/polars-plan/src/dsl/function_expr/clip.rs @@ -2,9 +2,9 @@ use super::*; pub(super) fn clip(s: &[Series], has_min: bool, has_max: bool) -> PolarsResult { match (has_min, has_max) { - (true, true) => polars_ops::prelude::clip(&s[0], &s[1], &s[2]), - (true, false) => polars_ops::prelude::clip_min(&s[0], &s[1]), - (false, true) => polars_ops::prelude::clip_max(&s[0], &s[1]), + (true, true) => polars_ops::series::clip(&s[0], &s[1], &s[2]), + (true, false) => polars_ops::series::clip_min(&s[0], &s[1]), + (false, true) => polars_ops::series::clip_max(&s[0], &s[1]), _ => unreachable!(), } } diff --git a/crates/polars-plan/src/dsl/function_expr/datetime.rs b/crates/polars-plan/src/dsl/function_expr/datetime.rs index 6ff69ca2f076..3836124b584d 100644 --- a/crates/polars-plan/src/dsl/function_expr/datetime.rs +++ b/crates/polars-plan/src/dsl/function_expr/datetime.rs @@ -289,10 +289,12 @@ pub(super) fn second(s: &Series) -> PolarsResult { s.second().map(|ca| ca.into_series()) } pub(super) fn millisecond(s: &Series) -> PolarsResult { - s.nanosecond().map(|ca| (ca / 1_000_000).into_series()) + s.nanosecond() + .map(|ca| (ca.wrapping_trunc_div_scalar(1_000_000)).into_series()) } pub(super) fn microsecond(s: &Series) -> PolarsResult { - s.nanosecond().map(|ca| (ca / 1_000).into_series()) + s.nanosecond() + .map(|ca| (ca.wrapping_trunc_div_scalar(1_000)).into_series()) } pub(super) fn nanosecond(s: &Series) -> PolarsResult { s.nanosecond().map(|ca| ca.into_series()) @@ -328,16 +330,12 @@ pub(super) fn to_string(s: &Series, format: &str) -> PolarsResult { #[cfg(feature = "timezones")] pub(super) fn convert_time_zone(s: &Series, time_zone: &TimeZone) -> PolarsResult { match s.dtype() { - DataType::Datetime(_, Some(_)) => { + DataType::Datetime(_, _) => { let mut ca = s.datetime()?.clone(); ca.set_time_zone(time_zone.clone())?; Ok(ca.into_series()) }, - _ => polars_bail!( - ComputeError: - "cannot call `convert_time_zone` on tz-naive; set a time zone first \ - with `replace_time_zone`" - ), + dtype => polars_bail!(ComputeError: "expected Datetime, got {}", dtype), } } pub(super) fn with_time_unit(s: &Series, tu: TimeUnit) -> PolarsResult { @@ -522,7 +520,7 @@ pub(super) fn duration(s: &[Series], time_unit: TimeUnit) -> PolarsResult PolarsResult PolarsResult PolarsResult> { - polars_ops::prelude::sum_horizontal(s) -} - pub(super) fn max_horizontal(s: &mut [Series]) -> PolarsResult> { polars_ops::prelude::max_horizontal(s) } @@ -85,6 +81,14 @@ pub(super) fn min_horizontal(s: &mut [Series]) -> PolarsResult> { polars_ops::prelude::min_horizontal(s) } +pub(super) fn sum_horizontal(s: &mut [Series]) -> PolarsResult> { + polars_ops::prelude::sum_horizontal(s) +} + +pub(super) fn mean_horizontal(s: &mut [Series]) -> PolarsResult> { + polars_ops::prelude::mean_horizontal(s) +} + pub(super) fn drop_nulls(s: &Series) -> PolarsResult { Ok(s.drop_nulls()) } @@ -156,3 +160,18 @@ pub(super) fn reinterpret(s: &Series, signed: bool) -> PolarsResult { pub(super) fn negate(s: &Series) -> PolarsResult { polars_ops::series::negate(s) } + +pub(super) fn extend_constant(s: &[Series]) -> PolarsResult { + let value = &s[1]; + let n = &s[2]; + polars_ensure!(value.len() == 1 && n.len() == 1, ComputeError: "value and n should have unit length."); + let n = n.strict_cast(&DataType::UInt64)?; + let v = value.get(0)?; + let s = &s[0]; + match n.u64()?.get(0) { + Some(n) => s.extend_constant(v, n as usize), + None => { + polars_bail!(ComputeError: "n can not be None for extend_constant.") + }, + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/fused.rs b/crates/polars-plan/src/dsl/function_expr/fused.rs index f9cb93c1b560..a95ac809ebc7 100644 --- a/crates/polars-plan/src/dsl/function_expr/fused.rs +++ b/crates/polars-plan/src/dsl/function_expr/fused.rs @@ -1,5 +1,3 @@ -use std::fmt::{Display, Formatter}; - #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 3bbd6f43a99c..53a20667eae9 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -24,6 +24,8 @@ pub enum ListFunction { Get, #[cfg(feature = "list_gather")] Gather(bool), + #[cfg(feature = "list_gather")] + GatherEvery, #[cfg(feature = "list_count")] CountMatches, Sum, @@ -31,6 +33,9 @@ pub enum ListFunction { Max, Min, Mean, + Median, + Std(u8), + Var(u8), ArgMin, ArgMax, #[cfg(feature = "diff")] @@ -41,6 +46,7 @@ pub enum ListFunction { Sort(SortOptions), Reverse, Unique(bool), + NUnique, #[cfg(feature = "list_sets")] SetOperation(SetOperation), #[cfg(feature = "list_any_all")] @@ -68,12 +74,17 @@ impl ListFunction { Get => mapper.map_to_list_and_array_inner_dtype(), #[cfg(feature = "list_gather")] Gather(_) => mapper.with_same_dtype(), + #[cfg(feature = "list_gather")] + GatherEvery => mapper.with_same_dtype(), #[cfg(feature = "list_count")] CountMatches => mapper.with_dtype(IDX_DTYPE), Sum => mapper.nested_sum_type(), Min => mapper.map_to_list_and_array_inner_dtype(), Max => mapper.map_to_list_and_array_inner_dtype(), Mean => mapper.with_dtype(DataType::Float64), + Median => mapper.map_to_float_dtype(), + Std(_) => mapper.map_to_float_dtype(), // Need to also have this sometimes marked as float32 or duration.. + Var(_) => mapper.map_to_float_dtype(), ArgMin => mapper.with_dtype(IDX_DTYPE), ArgMax => mapper.with_dtype(IDX_DTYPE), #[cfg(feature = "diff")] @@ -91,6 +102,7 @@ impl ListFunction { Join(_) => mapper.with_dtype(DataType::String), #[cfg(feature = "dtype-array")] ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)), + NUnique => mapper.with_dtype(IDX_DTYPE), } } } @@ -127,12 +139,17 @@ impl Display for ListFunction { Get => "get", #[cfg(feature = "list_gather")] Gather(_) => "gather", + #[cfg(feature = "list_gather")] + GatherEvery => "gather_every", #[cfg(feature = "list_count")] CountMatches => "count_matches", Sum => "sum", Min => "min", Max => "max", Mean => "mean", + Median => "median", + Std(_) => "std", + Var(_) => "var", ArgMin => "arg_min", ArgMax => "arg_max", #[cfg(feature = "diff")] @@ -147,6 +164,7 @@ impl Display for ListFunction { "unique" } }, + NUnique => "n_unique", #[cfg(feature = "list_sets")] SetOperation(s) => return write!(f, "list.{s}"), #[cfg(feature = "list_any_all")] @@ -188,6 +206,8 @@ impl From for SpecialEq> { Get => wrap!(get), #[cfg(feature = "list_gather")] Gather(null_ob_oob) => map_as_slice!(gather, null_ob_oob), + #[cfg(feature = "list_gather")] + GatherEvery => map_as_slice!(gather_every), #[cfg(feature = "list_count")] CountMatches => map_as_slice!(count_matches), Sum => map!(sum), @@ -195,6 +215,9 @@ impl From for SpecialEq> { Max => map!(max), Min => map!(min), Mean => map!(mean), + Median => map!(median), + Std(ddof) => map!(std, ddof), + Var(ddof) => map!(var, ddof), ArgMin => map!(arg_min), ArgMax => map!(arg_max), #[cfg(feature = "diff")] @@ -211,6 +234,7 @@ impl From for SpecialEq> { Join(ignore_nulls) => map_as_slice!(join, ignore_nulls), #[cfg(feature = "dtype-array")] ToArray(width) => map!(to_array, width), + NUnique => map!(n_unique), } } } @@ -455,6 +479,15 @@ pub(super) fn gather(args: &[Series], null_on_oob: bool) -> PolarsResult } } +#[cfg(feature = "list_gather")] +pub(super) fn gather_every(args: &[Series]) -> PolarsResult { + let ca = &args[0]; + let n = &args[1].strict_cast(&IDX_DTYPE)?; + let offset = &args[2].strict_cast(&IDX_DTYPE)?; + + ca.list()?.lst_gather_every(n.idx()?, offset.idx()?) +} + #[cfg(feature = "list_count")] pub(super) fn count_matches(args: &[Series]) -> PolarsResult { let s = &args[0]; @@ -488,6 +521,18 @@ pub(super) fn mean(s: &Series) -> PolarsResult { Ok(s.list()?.lst_mean()) } +pub(super) fn median(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_median()) +} + +pub(super) fn std(s: &Series, ddof: u8) -> PolarsResult { + Ok(s.list()?.lst_std(ddof)) +} + +pub(super) fn var(s: &Series, ddof: u8) -> PolarsResult { + Ok(s.list()?.lst_var(ddof)) +} + pub(super) fn arg_min(s: &Series) -> PolarsResult { Ok(s.list()?.lst_arg_min().into_series()) } @@ -566,3 +611,7 @@ pub(super) fn to_array(s: &Series, width: usize) -> PolarsResult { let array_dtype = map_list_dtype_to_array_dtype(s.dtype(), width)?; s.cast(&array_dtype) } + +pub(super) fn n_unique(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_n_unique()?.into_series()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index c7a8273e4142..397959e9980f 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -73,10 +73,6 @@ pub(crate) use correlation::CorrelationMethod; pub(crate) use fused::FusedOperator; pub(super) use list::ListFunction; use polars_core::prelude::*; -#[cfg(feature = "cutqcut")] -use polars_ops::prelude::{cut, qcut}; -#[cfg(feature = "rle")] -use polars_ops::prelude::{rle, rle_id}; #[cfg(feature = "random")] pub(crate) use random::RandomMethod; use schema::FieldsMapper; @@ -303,9 +299,10 @@ pub enum FunctionExpr { ForwardFill { limit: FillNullLimit, }, - SumHorizontal, MaxHorizontal, MinHorizontal, + SumHorizontal, + MeanHorizontal, #[cfg(feature = "ewma")] EwmMean { options: EWMOptions, @@ -328,6 +325,7 @@ pub enum FunctionExpr { }, #[cfg(feature = "reinterpret")] Reinterpret(bool), + ExtendConstant, } impl Hash for FunctionExpr { @@ -378,8 +376,8 @@ impl Hash for FunctionExpr { lib.hash(state); symbol.hash(state); }, - SumHorizontal | MaxHorizontal | MinHorizontal | DropNans | DropNulls | Reverse - | ArgUnique | Shift | ShiftAndFill => {}, + MaxHorizontal | MinHorizontal | SumHorizontal | MeanHorizontal | DropNans + | DropNulls | Reverse | ArgUnique | Shift | ShiftAndFill => {}, #[cfg(feature = "mode")] Mode => {}, #[cfg(feature = "abs")] @@ -534,6 +532,7 @@ impl Hash for FunctionExpr { GatherEvery { n, offset } => (n, offset).hash(state), #[cfg(feature = "reinterpret")] Reinterpret(signed) => signed.hash(state), + ExtendConstant => {}, } } } @@ -690,9 +689,10 @@ impl Display for FunctionExpr { FfiPlugin { lib, symbol, .. } => return write!(f, "{lib}:{symbol}"), BackwardFill { .. } => "backward_fill", ForwardFill { .. } => "forward_fill", - SumHorizontal => "sum_horizontal", MaxHorizontal => "max_horizontal", MinHorizontal => "min_horizontal", + SumHorizontal => "sum_horizontal", + MeanHorizontal => "mean_horizontal", #[cfg(feature = "ewma")] EwmMean { .. } => "ewm_mean", #[cfg(feature = "ewma")] @@ -707,6 +707,7 @@ impl Display for FunctionExpr { GatherEvery { .. } => "gather_every", #[cfg(feature = "reinterpret")] Reinterpret(_) => "reinterpret", + ExtendConstant => "extend_constant", }; write!(f, "{s}") } @@ -1046,9 +1047,10 @@ impl From for SpecialEq> { }, BackwardFill { limit } => map!(dispatch::backward_fill, limit), ForwardFill { limit } => map!(dispatch::forward_fill, limit), - SumHorizontal => wrap!(dispatch::sum_horizontal), MaxHorizontal => wrap!(dispatch::max_horizontal), MinHorizontal => wrap!(dispatch::min_horizontal), + SumHorizontal => wrap!(dispatch::sum_horizontal), + MeanHorizontal => wrap!(dispatch::mean_horizontal), #[cfg(feature = "ewma")] EwmMean { options } => map!(ewm::ewm_mean, options), #[cfg(feature = "ewma")] @@ -1063,6 +1065,7 @@ impl From for SpecialEq> { GatherEvery { n, offset } => map!(dispatch::gather_every, n, offset), #[cfg(feature = "reinterpret")] Reinterpret(signed) => map!(dispatch::reinterpret, signed), + ExtendConstant => map_as_slice!(dispatch::extend_constant), } } } diff --git a/crates/polars-plan/src/dsl/function_expr/range/date_range.rs b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs index bacda8bc45ae..1dd96e3f6af4 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/date_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs @@ -1,5 +1,4 @@ use polars_core::prelude::*; -use polars_core::series::Series; use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY; use polars_time::{datetime_range_impl, ClosedWindow, Duration}; diff --git a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs index 10cca7f3ccf6..3c61e60259c4 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs @@ -1,7 +1,6 @@ #[cfg(feature = "timezones")] use polars_core::chunked_array::temporal::parse_time_zone; use polars_core::prelude::*; -use polars_core::series::Series; use polars_time::{datetime_range_impl, ClosedWindow, Duration}; use super::utils::{ diff --git a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs index 3b4206e3ae0a..5344ec0b5ee8 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs @@ -1,5 +1,4 @@ use polars_core::prelude::*; -use polars_core::series::Series; use polars_core::with_match_physical_integer_polars_type; use polars_ops::series::new_int_range; diff --git a/crates/polars-plan/src/dsl/function_expr/range/mod.rs b/crates/polars-plan/src/dsl/function_expr/range/mod.rs index ab0508eff3de..dfee18e7e7cc 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/mod.rs @@ -10,7 +10,6 @@ mod utils; use std::fmt::{Display, Formatter}; use polars_core::prelude::*; -use polars_core::series::Series; #[cfg(feature = "temporal")] use polars_time::{ClosedWindow, Duration}; #[cfg(feature = "serde")] diff --git a/crates/polars-plan/src/dsl/function_expr/range/time_range.rs b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs index 2d799a471269..991368356cc5 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/time_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs @@ -1,5 +1,4 @@ use polars_core::prelude::*; -use polars_core::series::Series; use polars_time::{time_range_impl, ClosedWindow, Duration}; use super::utils::{ diff --git a/crates/polars-plan/src/dsl/function_expr/rolling.rs b/crates/polars-plan/src/dsl/function_expr/rolling.rs index 272ef6f6ba13..67772ef31adf 100644 --- a/crates/polars-plan/src/dsl/function_expr/rolling.rs +++ b/crates/polars-plan/src/dsl/function_expr/rolling.rs @@ -1,3 +1,5 @@ +use polars_time::chunkedarray::*; + use super::*; #[derive(Clone, PartialEq, Debug)] diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 505e0493cc1e..6a711353c52a 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -269,9 +269,10 @@ impl FunctionExpr { } => unsafe { plugin::plugin_field(fields, lib, symbol.as_ref(), kwargs) }, BackwardFill { .. } => mapper.with_same_dtype(), ForwardFill { .. } => mapper.with_same_dtype(), - SumHorizontal => mapper.map_to_supertype(), MaxHorizontal => mapper.map_to_supertype(), MinHorizontal => mapper.map_to_supertype(), + SumHorizontal => mapper.map_to_supertype(), + MeanHorizontal => mapper.map_to_float_dtype(), #[cfg(feature = "ewma")] EwmMean { .. } => mapper.map_to_float_dtype(), #[cfg(feature = "ewma")] @@ -291,6 +292,7 @@ impl FunctionExpr { }; mapper.with_dtype(dt) }, + ExtendConstant => mapper.with_same_dtype(), } } } diff --git a/crates/polars-plan/src/dsl/function_expr/search_sorted.rs b/crates/polars-plan/src/dsl/function_expr/search_sorted.rs index 9d4566dfce65..87933fc7bd6c 100644 --- a/crates/polars-plan/src/dsl/function_expr/search_sorted.rs +++ b/crates/polars-plan/src/dsl/function_expr/search_sorted.rs @@ -1,5 +1,3 @@ -use polars_ops::prelude::search_sorted; - use super::*; pub(super) fn search_sorted_impl(s: &mut [Series], side: SearchSortedSide) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 64e121099d43..13d18d790c63 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -47,6 +47,7 @@ pub enum StringFunction { dtype: DataType, pat: String, }, + #[cfg(feature = "regex")] Find { literal: bool, strict: bool, @@ -139,6 +140,7 @@ impl StringFunction { ExtractGroups { dtype, .. } => mapper.with_dtype(dtype.clone()), #[cfg(feature = "string_to_integer")] ToInteger { .. } => mapper.with_dtype(DataType::Int64), + #[cfg(feature = "regex")] Find { .. } => mapper.with_dtype(DataType::UInt32), #[cfg(feature = "extract_jsonpath")] JsonDecode { dtype, .. } => mapper.with_opt_dtype(dtype.clone()), @@ -206,6 +208,7 @@ impl Display for StringFunction { ExtractGroups { .. } => "extract_groups", #[cfg(feature = "string_to_integer")] ToInteger { .. } => "to_integer", + #[cfg(feature = "regex")] Find { .. } => "find", #[cfg(feature = "extract_jsonpath")] JsonDecode { .. } => "json_decode", @@ -289,6 +292,7 @@ impl From for SpecialEq> { ExtractGroups { pat, dtype } => { map!(strings::extract_groups, &pat, &dtype) }, + #[cfg(feature = "regex")] Find { literal, strict } => map_as_slice!(strings::find, literal, strict), LenBytes => map!(strings::len_bytes), LenChars => map!(strings::len_chars), diff --git a/crates/polars-plan/src/dsl/function_expr/struct_.rs b/crates/polars-plan/src/dsl/function_expr/struct_.rs index c927a14f1e58..440a108a8408 100644 --- a/crates/polars-plan/src/dsl/function_expr/struct_.rs +++ b/crates/polars-plan/src/dsl/function_expr/struct_.rs @@ -9,6 +9,8 @@ pub enum StructFunction { FieldByIndex(i64), FieldByName(Arc), RenameFields(Arc>), + PrefixFields(Arc), + SuffixFields(Arc), #[cfg(feature = "json")] JsonEncode, } @@ -60,6 +62,32 @@ impl StructFunction { .collect(), ), }), + PrefixFields(prefix) => mapper.try_map_dtype(|dt| match dt { + DataType::Struct(fields) => { + let fields = fields + .iter() + .map(|fld| { + let name = fld.name(); + Field::new(&format!("{prefix}{name}"), fld.data_type().clone()) + }) + .collect(); + Ok(DataType::Struct(fields)) + }, + _ => polars_bail!(op = "prefix_fields", got = dt, expected = "Struct"), + }), + SuffixFields(suffix) => mapper.try_map_dtype(|dt| match dt { + DataType::Struct(fields) => { + let fields = fields + .iter() + .map(|fld| { + let name = fld.name(); + Field::new(&format!("{name}{suffix}"), fld.data_type().clone()) + }) + .collect(); + Ok(DataType::Struct(fields)) + }, + _ => polars_bail!(op = "suffix_fields", got = dt, expected = "Struct"), + }), #[cfg(feature = "json")] JsonEncode => mapper.with_dtype(DataType::String), } @@ -73,6 +101,8 @@ impl Display for StructFunction { FieldByIndex(index) => write!(f, "struct.field_by_index({index})"), FieldByName(name) => write!(f, "struct.field_by_name({name})"), RenameFields(names) => write!(f, "struct.rename_fields({:?})", names), + PrefixFields(_) => write!(f, "name.prefix_fields"), + SuffixFields(_) => write!(f, "name.suffixFields"), #[cfg(feature = "json")] JsonEncode => write!(f, "struct.to_json"), } @@ -86,6 +116,8 @@ impl From for SpecialEq> { FieldByIndex(index) => map!(struct_::get_by_index, index), FieldByName(name) => map!(struct_::get_by_name, name.clone()), RenameFields(names) => map!(struct_::rename_fields, names.clone()), + PrefixFields(prefix) => map!(struct_::prefix_fields, prefix.clone()), + SuffixFields(suffix) => map!(struct_::suffix_fields, suffix.clone()), #[cfg(feature = "json")] JsonEncode => map!(struct_::to_json), } @@ -120,14 +152,45 @@ pub(super) fn rename_fields(s: &Series, names: Arc>) -> PolarsResult StructChunked::new(ca.name(), &fields).map(|ca| ca.into_series()) } +pub(super) fn prefix_fields(s: &Series, prefix: Arc) -> PolarsResult { + let ca = s.struct_()?; + let fields = ca + .fields() + .iter() + .map(|s| { + let mut s = s.clone(); + let name = s.name(); + s.rename(&format!("{prefix}{name}")); + s + }) + .collect::>(); + StructChunked::new(ca.name(), &fields).map(|ca| ca.into_series()) +} + +pub(super) fn suffix_fields(s: &Series, suffix: Arc) -> PolarsResult { + let ca = s.struct_()?; + let fields = ca + .fields() + .iter() + .map(|s| { + let mut s = s.clone(); + let name = s.name(); + s.rename(&format!("{name}{suffix}")); + s + }) + .collect::>(); + StructChunked::new(ca.name(), &fields).map(|ca| ca.into_series()) +} + #[cfg(feature = "json")] pub(super) fn to_json(s: &Series) -> PolarsResult { let ca = s.struct_()?; + let dtype = ca.dtype().to_arrow(true); - let iter = ca - .chunks() - .iter() - .map(|arr| polars_json::json::write::serialize_to_utf8(arr.as_ref())); + let iter = ca.chunks().iter().map(|arr| { + let arr = arrow::compute::cast::cast_unchecked(arr.as_ref(), &dtype).unwrap(); + polars_json::json::write::serialize_to_utf8(arr.as_ref()) + }); Ok(StringChunked::from_chunk_iter(ca.name(), iter).into_series()) } diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index 74595a134612..07c01ec53be5 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -287,10 +287,7 @@ pub fn min_horizontal>(exprs: E) -> PolarsResult { }) } -/// Create a new column with the sum of the values in each row. -/// -/// The name of the resulting column will be `"sum"`. -/// Use [`alias`](Expr::alias) to choose a different name. +/// Sum all values horizontally across columns. pub fn sum_horizontal>(exprs: E) -> PolarsResult { let exprs = exprs.as_ref().to_vec(); polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); @@ -303,7 +300,24 @@ pub fn sum_horizontal>(exprs: E) -> PolarsResult { input_wildcard_expansion: true, returns_scalar: false, cast_to_supertypes: false, - allow_rename: true, + ..Default::default() + }, + }) +} + +/// Compute the mean of all values horizontally across columns. +pub fn mean_horizontal>(exprs: E) -> PolarsResult { + let exprs = exprs.as_ref().to_vec(); + polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); + + Ok(Expr::Function { + input: exprs, + function: FunctionExpr::MeanHorizontal, + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + input_wildcard_expansion: true, + returns_scalar: false, + cast_to_supertypes: false, ..Default::default() }, }) diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 9f0219e566a8..603ec2553590 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -5,7 +5,6 @@ use polars_core::prelude::*; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; -use crate::dsl::function_expr::FunctionExpr; use crate::prelude::function_expr::ListFunction; use crate::prelude::*; @@ -107,6 +106,21 @@ impl ListNameSpace { .map_private(FunctionExpr::ListExpr(ListFunction::Mean)) } + pub fn median(self) -> Expr { + self.0 + .map_private(FunctionExpr::ListExpr(ListFunction::Median)) + } + + pub fn std(self, ddof: u8) -> Expr { + self.0 + .map_private(FunctionExpr::ListExpr(ListFunction::Std(ddof))) + } + + pub fn var(self, ddof: u8) -> Expr { + self.0 + .map_private(FunctionExpr::ListExpr(ListFunction::Var(ddof))) + } + /// Sort every sublist. pub fn sort(self, options: SortOptions) -> Expr { self.0 @@ -131,6 +145,11 @@ impl ListNameSpace { .map_private(FunctionExpr::ListExpr(ListFunction::Unique(true))) } + pub fn n_unique(self) -> Expr { + self.0 + .map_private(FunctionExpr::ListExpr(ListFunction::NUnique)) + } + /// Get items in every sublist by index. pub fn get(self, index: Expr) -> Expr { self.0.map_many_private( @@ -147,7 +166,7 @@ impl ListNameSpace { /// - `null_on_oob`: Return a null when an index is out of bounds. /// This behavior is more expensive than defaulting to returning an `Error`. #[cfg(feature = "list_gather")] - pub fn take(self, index: Expr, null_on_oob: bool) -> Expr { + pub fn gather(self, index: Expr, null_on_oob: bool) -> Expr { self.0.map_many_private( FunctionExpr::ListExpr(ListFunction::Gather(null_on_oob)), &[index], @@ -156,6 +175,16 @@ impl ListNameSpace { ) } + #[cfg(feature = "list_gather")] + pub fn gather_every(self, n: Expr, offset: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::GatherEvery), + &[n, offset], + false, + false, + ) + } + /// Get first item of every sublist. pub fn first(self) -> Expr { self.get(lit(0i64)) diff --git a/crates/polars-plan/src/dsl/meta.rs b/crates/polars-plan/src/dsl/meta.rs index 6000864dee70..28a554007a50 100644 --- a/crates/polars-plan/src/dsl/meta.rs +++ b/crates/polars-plan/src/dsl/meta.rs @@ -2,7 +2,6 @@ use std::fmt::Display; use std::ops::BitAnd; use super::*; -use crate::dsl::selector::Selector; use crate::logical_plan::projection::is_regex_projection; use crate::logical_plan::tree_format::TreeFmtVisitor; use crate::logical_plan::visitor::{AexprNode, TreeWalker}; @@ -137,7 +136,7 @@ impl MetaNameSpace { pub fn into_tree_formatter(self) -> PolarsResult { let mut arena = Default::default(); let node = to_aexpr(self.0, &mut arena); - let mut visitor = TreeFmtVisitor::new(); + let mut visitor = TreeFmtVisitor::default(); AexprNode::with_context(node, &mut arena, |ae_node| ae_node.visit(&mut visitor))?; Ok(visitor) } diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index b17a64452f70..157562bf6e8b 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -8,6 +8,8 @@ use std::any::Any; #[cfg(feature = "dtype-categorical")] pub use cat::*; +#[cfg(feature = "rolling_window")] +pub(crate) use polars_time::prelude::*; mod arithmetic; mod arity; #[cfg(feature = "dtype-array")] @@ -52,14 +54,13 @@ pub use functions::*; pub use list::*; #[cfg(feature = "meta")] pub use meta::*; +pub use name::*; pub use options::*; use polars_core::prelude::*; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; use polars_core::series::IsSorted; use polars_core::utils::try_get_supertype; -#[cfg(feature = "rolling_window")] -use polars_time::prelude::SeriesOpsTime; pub(crate) use selector::Selector; #[cfg(feature = "dtype-struct")] pub use struct_::*; @@ -68,9 +69,6 @@ pub use udf::UserDefinedFunction; use crate::constants::MAP_LIST_NAME; pub use crate::logical_plan::lit; use crate::prelude::*; -use crate::utils::has_expr; -#[cfg(feature = "is_in")] -use crate::utils::has_leaf_literal; impl Expr { /// Modify the Options passed to the `Function` node. @@ -1631,6 +1629,10 @@ impl Expr { self.map_private(FunctionExpr::Reinterpret(signed)) } + pub fn extend_constant(self, value: Expr, n: Expr) -> Expr { + self.apply_many_private(FunctionExpr::ExtendConstant, &[value, n], false, false) + } + #[cfg(feature = "strings")] /// Get the [`string::StringNameSpace`] pub fn str(self) -> string::StringNameSpace { diff --git a/crates/polars-plan/src/dsl/name.rs b/crates/polars-plan/src/dsl/name.rs index 68d38f433917..8034705cacfe 100644 --- a/crates/polars-plan/src/dsl/name.rs +++ b/crates/polars-plan/src/dsl/name.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "dtype-struct")] +use smartstring::alias::String as SmartString; + use super::*; /// Specialized expressions for modifying the name of existing expressions. @@ -56,4 +59,53 @@ impl ExprNameNameSpace { pub fn to_uppercase(self) -> Expr { self.map(move |name| Ok(name.to_uppercase())) } + + #[cfg(feature = "dtype-struct")] + pub fn map_fields(self, function: FieldsNameMapper) -> Expr { + let f = function.clone(); + self.0.map( + move |s| { + let s = s.struct_()?; + let fields = s + .fields() + .iter() + .map(|fd| { + let mut fd = fd.clone(); + fd.rename(&function(fd.name())); + fd + }) + .collect::>(); + StructChunked::new(s.name(), &fields).map(|ca| Some(ca.into_series())) + }, + GetOutput::map_dtype(move |dt| match dt { + DataType::Struct(fds) => { + let fields = fds + .iter() + .map(|fd| Field::new(&f(fd.name()), fd.data_type().clone())) + .collect(); + DataType::Struct(fields) + }, + _ => panic!("Only struct dtype is supported for `map_fields`."), + }), + ) + } + + #[cfg(feature = "dtype-struct")] + pub fn prefix_fields(self, prefix: &str) -> Expr { + self.0 + .map_private(FunctionExpr::StructExpr(StructFunction::PrefixFields( + Arc::from(prefix), + ))) + } + + #[cfg(feature = "dtype-struct")] + pub fn suffix_fields(self, suffix: &str) -> Expr { + self.0 + .map_private(FunctionExpr::StructExpr(StructFunction::SuffixFields( + Arc::from(suffix), + ))) + } } + +#[cfg(feature = "dtype-struct")] +pub type FieldsNameMapper = Arc SmartString + Send + Sync>; diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index b15e1fbe9b5d..e1fa05d419f0 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -1,7 +1,6 @@ use std::io::Cursor; use std::sync::Arc; -use arrow::legacy::error::PolarsResult; use polars_core::datatypes::{DataType, Field}; use polars_core::error::*; use polars_core::frame::DataFrame; diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 42a7cb2471fd..88c43c4e5ff7 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -1,4 +1,3 @@ -use super::function_expr::StringFunction; use super::*; /// Specialized expressions for [`Series`] of [`DataType::String`]. pub struct StringNameSpace(pub(crate) Expr); diff --git a/crates/polars-plan/src/dsl/struct_.rs b/crates/polars-plan/src/dsl/struct_.rs index eb6c066dca3f..db02ff023045 100644 --- a/crates/polars-plan/src/dsl/struct_.rs +++ b/crates/polars-plan/src/dsl/struct_.rs @@ -1,5 +1,4 @@ use super::*; -use crate::dsl::function_expr::StructFunction; /// Specialized expressions for Struct dtypes. pub struct StructNameSpace(pub(crate) Expr); diff --git a/crates/polars-plan/src/logical_plan/aexpr/mod.rs b/crates/polars-plan/src/logical_plan/aexpr/mod.rs index 704d3fc1d1c7..300ed2b9472b 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/mod.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/mod.rs @@ -2,16 +2,12 @@ mod hash; mod schema; use std::hash::{Hash, Hasher}; -use std::sync::Arc; -use arrow::legacy::prelude::QuantileInterpolOptions; -use polars_core::frame::group_by::GroupByMethod; use polars_core::prelude::*; use polars_core::utils::{get_time_units, try_get_supertype}; use polars_utils::arena::{Arena, Node}; use strum_macros::IntoStaticStr; -use crate::dsl::function_expr::FunctionExpr; #[cfg(feature = "cse")] use crate::logical_plan::visitor::AexprNode; use crate::logical_plan::Context; @@ -196,7 +192,7 @@ impl AExpr { #[cfg(feature = "cse")] pub(crate) fn is_equal(l: Node, r: Node, arena: &Arena) -> bool { let arena = arena as *const Arena as *mut Arena; - // safety: we can pass a *mut pointer + // SAFETY: we can pass a *mut pointer // the equality operation will not access mutable unsafe { let ae_node_l = AexprNode::from_raw(l, arena); @@ -255,38 +251,38 @@ impl AExpr { } /// Push nodes at this level to a pre-allocated stack - pub(crate) fn nodes(&self, container: &mut Vec) { + pub(crate) fn nodes(&self, container: &mut C) { use AExpr::*; match self { Nth(_) | Column(_) | Literal(_) | Wildcard | Len => {}, - Alias(e, _) => container.push(*e), + Alias(e, _) => container.push_node(*e), BinaryExpr { left, op: _, right } => { // reverse order so that left is popped first - container.push(*right); - container.push(*left); + container.push_node(*right); + container.push_node(*left); }, - Cast { expr, .. } => container.push(*expr), - Sort { expr, .. } => container.push(*expr), + Cast { expr, .. } => container.push_node(*expr), + Sort { expr, .. } => container.push_node(*expr), Gather { expr, idx, .. } => { - container.push(*idx); + container.push_node(*idx); // latest, so that it is popped first - container.push(*expr); + container.push_node(*expr); }, SortBy { expr, by, .. } => { for node in by { - container.push(*node) + container.push_node(*node) } // latest, so that it is popped first - container.push(*expr); + container.push_node(*expr); }, Filter { input, by } => { - container.push(*by); + container.push_node(*by); // latest, so that it is popped first - container.push(*input); + container.push_node(*input); }, Agg(agg_e) => match agg_e.get_input() { - NodeInputs::Single(node) => container.push(node), + NodeInputs::Single(node) => container.push_node(node), NodeInputs::Many(nodes) => container.extend_from_slice(&nodes), NodeInputs::Leaf => {}, }, @@ -295,10 +291,10 @@ impl AExpr { falsy, predicate, } => { - container.push(*predicate); - container.push(*falsy); + container.push_node(*predicate); + container.push_node(*falsy); // latest, so that it is popped first - container.push(*truthy); + container.push_node(*truthy); }, AnonymousFunction { input, .. } | Function { input, .. } => // we iterate in reverse order, so that the lhs is popped first and will be found @@ -308,29 +304,29 @@ impl AExpr { .iter() .rev() .copied() - .for_each(|node| container.push(node)) + .for_each(|node| container.push_node(node)) }, - Explode(e) => container.push(*e), + Explode(e) => container.push_node(*e), Window { function, partition_by, options: _, } => { for e in partition_by.iter().rev() { - container.push(*e); + container.push_node(*e); } // latest so that it is popped first - container.push(*function); + container.push_node(*function); }, Slice { input, offset, length, } => { - container.push(*length); - container.push(*offset); + container.push_node(*length); + container.push_node(*offset); // latest so that it is popped first - container.push(*input); + container.push_node(*input); }, } } diff --git a/crates/polars-plan/src/logical_plan/alp.rs b/crates/polars-plan/src/logical_plan/alp.rs index 45dcd2624c2e..547be5f05592 100644 --- a/crates/polars-plan/src/logical_plan/alp.rs +++ b/crates/polars-plan/src/logical_plan/alp.rs @@ -1,16 +1,12 @@ use std::borrow::Cow; use std::path::PathBuf; -use std::sync::Arc; use polars_core::prelude::*; -use polars_utils::arena::{Arena, Node}; +use polars_utils::idx_vec::UnitVec; +use polars_utils::unitvec; use super::projection_expr::*; -use crate::logical_plan::functions::FunctionNode; -use crate::logical_plan::schema::FileInfo; -use crate::logical_plan::FileScan; use crate::prelude::*; -use crate::utils::PushNode; /// [`ALogicalPlan`] is a representation of [`LogicalPlan`] with [`Node`]s which are allocated in an [`Arena`] #[derive(Clone, Debug)] @@ -474,9 +470,9 @@ impl ALogicalPlan { feature = "fused" ))] pub(crate) fn get_input(&self) -> Option { - let mut inputs = [None, None]; + let mut inputs: UnitVec = unitvec!(); self.copy_inputs(&mut inputs); - inputs[0] + inputs.first().copied() } } diff --git a/crates/polars-plan/src/logical_plan/builder.rs b/crates/polars-plan/src/logical_plan/builder.rs index 00ea6ddd2000..12af640e2d68 100644 --- a/crates/polars-plan/src/logical_plan/builder.rs +++ b/crates/polars-plan/src/logical_plan/builder.rs @@ -1,7 +1,6 @@ #[cfg(feature = "csv")] use std::io::{Read, Seek}; -use polars_core::frame::explode::MeltArgs; use polars_core::prelude::*; #[cfg(feature = "parquet")] use polars_io::cloud::CloudOptions; @@ -31,9 +30,7 @@ use polars_io::{ use super::builder_functions::*; use crate::dsl::functions::horizontal::all_horizontal; -use crate::logical_plan::functions::FunctionNode; use crate::logical_plan::projection::{is_regex_projection, rewrite_projections}; -use crate::logical_plan::schema::{det_join_schema, FileInfo}; #[cfg(feature = "python")] use crate::prelude::python_udf::PythonFunction; use crate::prelude::*; @@ -59,18 +56,23 @@ fn format_err(msg: &str, input: &LogicalPlan) -> String { format!("{msg}\n\nError originated just after this operation:\n{input:?}") } -/// Returns every error or msg: &str as `ComputeError`. -/// It also shows the logical plan node where the error -/// originated. +/// Returns every error or msg: &str as `ComputeError`. It also shows the logical plan node where the error originated. +/// If `input` is already a `LogicalPlan::Error`, then return it as is; errors already keep track of their previous +/// inputs, so we don't have to do it again here. macro_rules! raise_err { ($err:expr, $input:expr, $convert:ident) => {{ - let format_err_outer = |msg: &str| format_err(msg, &$input); - - let err = $err.wrap_msg(&format_err_outer); - - LogicalPlan::Error { - input: Box::new($input.clone()), - err: err.into(), + let input: LogicalPlan = $input.clone(); + match &input { + LogicalPlan::Error { .. } => input, + _ => { + let format_err_outer = |msg: &str| format_err(msg, &input); + let err = $err.wrap_msg(&format_err_outer); + + LogicalPlan::Error { + input: Box::new(input), + err: err.into(), // PolarsError -> ErrorState + } + }, } .$convert() }}; @@ -298,6 +300,7 @@ impl LogicalPlanBuilder { try_parse_dates: bool, raise_if_empty: bool, truncate_ragged_lines: bool, + mut n_threads: Option, ) -> PolarsResult { let path = path.into(); let mut file = polars_utils::open_file(&path)?; @@ -336,6 +339,7 @@ impl LogicalPlanBuilder { null_values.as_ref(), try_parse_dates, raise_if_empty, + &mut n_threads, )?; if let Some(rc) = &row_index { @@ -390,6 +394,7 @@ impl LogicalPlanBuilder { try_parse_dates, raise_if_empty, truncate_ragged_lines, + n_threads, }, }, } @@ -426,7 +431,7 @@ impl LogicalPlanBuilder { if columns.is_empty() { self.map( - |_| Ok(DataFrame::new_no_checks(vec![])), + |_| Ok(DataFrame::empty()), AllowedOptimizations::default(), Some(Arc::new(|_: &Schema| Ok(Arc::new(Schema::default())))), "EMPTY PROJECTION", @@ -451,7 +456,7 @@ impl LogicalPlanBuilder { if exprs.is_empty() { self.map( - |_| Ok(DataFrame::new_no_checks(vec![])), + |_| Ok(DataFrame::empty()), AllowedOptimizations::default(), Some(Arc::new(|_: &Schema| Ok(Arc::new(Schema::default())))), "EMPTY PROJECTION", @@ -536,7 +541,10 @@ impl LogicalPlanBuilder { if !output_names.insert(field.name().clone()) { let msg = format!( - "The name: '{}' passed to `LazyFrame.with_columns` is duplicate", + "the name: '{}' passed to `LazyFrame.with_columns` is duplicate\n\n\ + It's possible that multiple expressions are returning the same default column name. \ + If this is the case, try renaming the columns with `.alias(\"new_name\")` to avoid \ + duplicate column names.", field.name() ); return raise_err!(polars_err!(ComputeError: msg), &self.0, into); diff --git a/crates/polars-plan/src/logical_plan/builder_functions.rs b/crates/polars-plan/src/logical_plan/builder_functions.rs index d74994d10b0e..5d207abd42b3 100644 --- a/crates/polars-plan/src/logical_plan/builder_functions.rs +++ b/crates/polars-plan/src/logical_plan/builder_functions.rs @@ -35,29 +35,23 @@ pub(super) fn det_melt_schema(args: &MeltArgs, input_schema: &Schema) -> SchemaR new_schema.with_column(variable_name, DataType::String); // We need to determine the supertype of all value columns. - let mut st = None; + let mut supertype = DataType::Null; // take all columns that are not in `id_vars` as `value_var` if args.value_vars.is_empty() { let id_vars = PlHashSet::from_iter(&args.id_vars); for (name, dtype) in input_schema.iter() { if !id_vars.contains(name) { - match &st { - None => st = Some(dtype.clone()), - Some(st_) => st = Some(try_get_supertype(st_, dtype).unwrap()), - } + supertype = try_get_supertype(&supertype, dtype).unwrap(); } } } else { for name in &args.value_vars { let dtype = input_schema.get(name).unwrap(); - match &st { - None => st = Some(dtype.clone()), - Some(st_) => st = Some(try_get_supertype(st_, dtype).unwrap()), - } + supertype = try_get_supertype(&supertype, dtype).unwrap(); } } - new_schema.with_column(value_name, st.unwrap()); + new_schema.with_column(value_name, supertype); Arc::new(new_schema) } diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index 413818c46478..91c384aa86bb 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -29,7 +29,7 @@ fn write_scan( )), }; - write!(f, "{:indent$}{} SCAN {}", "", name, path_fmt)?; + write!(f, "{:indent$}{name} SCAN {path_fmt}", "")?; if n_columns > 0 { write!( f, @@ -37,7 +37,7 @@ fn write_scan( "", )?; } else { - write!(f, "\n{:indent$}PROJECT */{total_columns} COLUMNS", "",)?; + write!(f, "\n{:indent$}PROJECT */{total_columns} COLUMNS", "")?; } if let Some(predicate) = predicate { write!(f, "\n{:indent$}SELECTION: {predicate}", "")?; @@ -79,7 +79,7 @@ impl LogicalPlan { Union { inputs, options } => { let mut name = String::new(); let name = if let Some(slice) = options.slice { - write!(name, "SLICED UNION: {:?}", slice)?; + write!(name, "SLICED UNION: {slice:?}")?; name.as_str() } else { "UNION" @@ -89,12 +89,12 @@ impl LogicalPlan { // - 1 => PLAN 0, PLAN 1, ... PLAN N // - 2 => actual formatting of plans let sub_sub_indent = sub_indent + 2; - write!(f, "{:indent$}{}", "", name)?; + write!(f, "{:indent$}{name}", "")?; for (i, plan) in inputs.iter().enumerate() { write!(f, "\n{:sub_indent$}PLAN {i}:", "")?; plan._format(f, sub_sub_indent)?; } - write!(f, "\n{:indent$}END {}", "", name) + write!(f, "\n{:indent$}END {name}", "") }, HConcat { inputs, .. } => { let sub_sub_indent = sub_indent + 2; @@ -194,7 +194,7 @@ impl LogicalPlan { input_left._format(f, sub_indent)?; write!(f, "\n{:indent$}RIGHT PLAN ON: {right_on:?}", "")?; input_right._format(f, sub_indent)?; - write!(f, "\n{:indent$}END {} JOIN", "", how) + write!(f, "\n{:indent$}END {how} JOIN", "") }, HStack { input, exprs, .. } => { write!(f, "{:indent$} WITH_COLUMNS:", "",)?; @@ -202,7 +202,11 @@ impl LogicalPlan { input._format(f, sub_indent) }, Distinct { input, options } => { - write!(f, "{:indent$}UNIQUE BY {:?}", "", options.subset)?; + write!( + f, + "{:indent$}UNIQUE[maintain_order: {:?}, keep_strategy: {:?}] BY {:?}", + "", options.maintain_order, options.keep_strategy, options.subset + )?; input._format(f, sub_indent) }, Slice { input, offset, len } => { @@ -216,7 +220,7 @@ impl LogicalPlan { write!(f, "{:indent$}{function_fmt}", "")?; input._format(f, sub_indent) }, - Error { input, err } => write!(f, "{err:?}\n{input:?}"), + Error { err, .. } => write!(f, "{err:?}"), ExtContext { input, .. } => { write!(f, "{:indent$}EXTERNAL_CONTEXT", "")?; input._format(f, sub_indent) @@ -228,7 +232,7 @@ impl LogicalPlan { #[cfg(feature = "cloud")] SinkType::Cloud { .. } => "SINK (cloud)", }; - write!(f, "{:indent$}{}", "", name)?; + write!(f, "{:indent$}{name}", "")?; input._format(f, sub_indent) }, } @@ -285,9 +289,12 @@ impl Debug for Expr { } }, BinaryExpr { left, op, right } => write!(f, "[({left:?}) {op:?} ({right:?})]"), - Sort { expr, options } => match options.descending { - true => write!(f, "{expr:?}.sort(desc)"), - false => write!(f, "{expr:?}.sort(asc)"), + Sort { expr, options } => { + if options.descending { + write!(f, "{expr:?}.sort(desc)") + } else { + write!(f, "{expr:?}.sort(asc)") + } }, SortBy { expr, @@ -373,7 +380,7 @@ impl Debug for Expr { input, function, .. } => { if input.len() >= 2 { - write!(f, "{:?}.{}({:?})", input[0], function, &input[1..]) + write!(f, "{:?}.{function}({:?})", input[0], &input[1..]) } else { write!(f, "{:?}.{function}()", input[0]) } @@ -423,7 +430,7 @@ impl Debug for LiteralValue { } }, _ => { - let av = self.to_anyvalue().unwrap(); + let av = self.to_any_value().unwrap(); write!(f, "{av}") }, } diff --git a/crates/polars-plan/src/logical_plan/functions/count.rs b/crates/polars-plan/src/logical_plan/functions/count.rs new file mode 100644 index 000000000000..b84322fbb814 --- /dev/null +++ b/crates/polars-plan/src/logical_plan/functions/count.rs @@ -0,0 +1,105 @@ +#[cfg(feature = "parquet")] +use arrow::io::ipc::read::get_row_count as count_rows_ipc; +#[cfg(feature = "parquet")] +use polars_io::cloud::CloudOptions; +#[cfg(feature = "csv")] +use polars_io::csv::count_rows as count_rows_csv; +#[cfg(all(feature = "parquet", feature = "cloud"))] +use polars_io::parquet::ParquetAsyncReader; +#[cfg(feature = "parquet")] +use polars_io::parquet::ParquetReader; +#[cfg(all(feature = "parquet", feature = "async"))] +use polars_io::pl_async::{get_runtime, with_concurrency_budget}; +#[cfg(feature = "parquet")] +use polars_io::{is_cloud_url, SerReader}; + +use super::*; + +#[allow(unused_variables)] +pub fn count_rows(paths: &Arc<[PathBuf]>, scan_type: &FileScan) -> PolarsResult { + match scan_type { + #[cfg(feature = "csv")] + FileScan::Csv { options } => { + let n_rows: PolarsResult = paths + .iter() + .map(|path| { + count_rows_csv( + path, + options.separator, + options.quote_char, + options.comment_prefix.as_ref(), + options.eol_char, + options.has_header, + ) + }) + .sum(); + Ok(DataFrame::new(vec![Series::new("len", [n_rows? as IdxSize])]).unwrap()) + }, + #[cfg(feature = "parquet")] + FileScan::Parquet { cloud_options, .. } => { + let n_rows = count_rows_parquet(paths, cloud_options.as_ref())?; + Ok(DataFrame::new(vec![Series::new("len", [n_rows as IdxSize])]).unwrap()) + }, + #[cfg(feature = "ipc")] + FileScan::Ipc { options } => { + let n_rows: PolarsResult = paths + .iter() + .map(|path| { + let mut reader = polars_utils::open_file(path)?; + count_rows_ipc(&mut reader) + }) + .sum(); + Ok(DataFrame::new(vec![Series::new("len", [n_rows? as IdxSize])]).unwrap()) + }, + FileScan::Anonymous { .. } => { + unreachable!(); + }, + } +} +#[cfg(feature = "parquet")] +pub(super) fn count_rows_parquet( + paths: &Arc<[PathBuf]>, + cloud_options: Option<&CloudOptions>, +) -> PolarsResult { + if paths.is_empty() { + return Ok(0); + }; + let is_cloud = is_cloud_url(paths.first().unwrap().as_path()); + + if is_cloud { + #[cfg(not(feature = "cloud"))] + panic!("One or more of the cloud storage features ('aws', 'gcp', ...) must be enabled."); + + #[cfg(feature = "cloud")] + { + get_runtime().block_on(count_rows_cloud_parquet(paths, cloud_options)) + } + } else { + paths + .iter() + .map(|path| { + let file = polars_utils::open_file(path)?; + let mut reader = ParquetReader::new(file); + reader.num_rows() + }) + .sum::>() + } +} + +#[cfg(all(feature = "parquet", feature = "async"))] +async fn count_rows_cloud_parquet( + paths: &Arc<[PathBuf]>, + cloud_options: Option<&CloudOptions>, +) -> PolarsResult { + let collection = paths.iter().map(|path| { + with_concurrency_budget(1, || async { + let mut reader = + ParquetAsyncReader::from_uri(&path.to_string_lossy(), cloud_options, None, None) + .await?; + reader.num_rows().await + }) + }); + futures::future::try_join_all(collection) + .await + .map(|rows| rows.iter().sum()) +} diff --git a/crates/polars-plan/src/logical_plan/functions/merge_sorted.rs b/crates/polars-plan/src/logical_plan/functions/merge_sorted.rs index 4e541c6c169d..a20a85d68812 100644 --- a/crates/polars-plan/src/logical_plan/functions/merge_sorted.rs +++ b/crates/polars-plan/src/logical_plan/functions/merge_sorted.rs @@ -2,7 +2,7 @@ use polars_core::prelude::*; use polars_ops::prelude::*; pub(super) fn merge_sorted(df: &DataFrame, column: &str) -> PolarsResult { - // Safety: + // SAFETY: // the dtype is known let (left_cols, right_cols) = unsafe { ( @@ -29,8 +29,8 @@ pub(super) fn merge_sorted(df: &DataFrame, column: &str) -> PolarsResult, + scan_type: FileScan, + }, #[cfg_attr(feature = "serde", serde(skip))] Pipeline { function: Arc, @@ -111,6 +114,7 @@ impl PartialEq for FunctionNode { ) => l == r && dl == dr, (DropNulls { subset: l }, DropNulls { subset: r }) => l == r, (Rechunk, Rechunk) => true, + (Count { paths: paths_l, .. }, Count { paths: paths_r, .. }) => paths_l == paths_r, ( Rename { existing: existing_l, @@ -141,6 +145,7 @@ impl FunctionNode { MergeSorted { .. } => false, DropNulls { .. } | FastProjection { .. } + | Count { .. } | Unnest { .. } | Rename { .. } | Explode { .. } => true, @@ -193,6 +198,11 @@ impl FunctionNode { Ok(Cow::Owned(Arc::new(schema))) }, DropNulls { .. } => Ok(Cow::Borrowed(input_schema)), + Count { .. } => { + let mut schema: Schema = Schema::with_capacity(1); + schema.insert_at_index(0, SmartString::from("len"), IDX_DTYPE)?; + Ok(Cow::Owned(Arc::new(schema))) + }, Rechunk => Ok(Cow::Borrowed(input_schema)), Unnest { columns: _columns } => { #[cfg(feature = "dtype-struct")] @@ -254,7 +264,7 @@ impl FunctionNode { | Melt { .. } => true, #[cfg(feature = "merge_sorted")] MergeSorted { .. } => true, - RowIndex { .. } => false, + RowIndex { .. } | Count { .. } => false, Pipeline { .. } => unimplemented!(), } } @@ -268,6 +278,7 @@ impl FunctionNode { FastProjection { .. } | DropNulls { .. } | Rechunk + | Count { .. } | Unnest { .. } | Rename { .. } | Explode { .. } @@ -312,6 +323,7 @@ impl FunctionNode { } }, DropNulls { subset } => df.drop_nulls(Some(subset.as_ref())), + Count { paths, scan_type } => count::count_rows(paths, scan_type), Rechunk => { df.as_single_chunk_par(); Ok(df) @@ -376,6 +388,7 @@ impl Display for FunctionNode { fmt_column_delimited(f, subset, "[", "]") }, Rechunk => write!(f, "RECHUNK"), + Count { .. } => write!(f, "FAST COUNT(*)"), Unnest { columns } => { write!(f, "UNNEST by:")?; let columns = columns.as_ref(); diff --git a/crates/polars-plan/src/logical_plan/hive.rs b/crates/polars-plan/src/logical_plan/hive.rs index dc71758e3581..46f4d5a50722 100644 --- a/crates/polars-plan/src/logical_plan/hive.rs +++ b/crates/polars-plan/src/logical_plan/hive.rs @@ -38,15 +38,30 @@ impl HivePartitions { pub(crate) fn parse_url(url: &Path) -> Option { let sep = separator(url); - let partitions = url - .display() - .to_string() - .split(sep) - .filter_map(|part| { + let url_string = url.display().to_string(); + + let pre_filt = url_string.split(sep); + + let split_count_m1 = pre_filt.clone().count() - 1; + + let partitions = pre_filt + .enumerate() + .filter_map(|(index, part)| { let mut it = part.split('='); let name = it.next()?; let value = it.next()?; + // Don't see files `foo=1.parquet` as hive partitions. + // So we return globs and paths with extensions. + if value.contains('*') { + return None; + } + + // Identify file by index location + if index == split_count_m1 { + return None; + } + // Having multiple '=' doesn't seem like valid hive partition, // continue as url. if it.next().is_some() { @@ -63,7 +78,7 @@ impl HivePartitions { let value = value.parse::().ok()?; Series::new(name, &[value]) } else if value == "__HIVE_DEFAULT_PARTITION__" { - Series::full_null(name, 1, &DataType::Null) + Series::new_null(name, 1) } else { Series::new(name, &[percent_decode_str(value).decode_utf8().ok()?]) }; @@ -81,6 +96,7 @@ impl HivePartitions { .into_iter() .map(ColumnStats::from_column_literal) .collect(), + None, ); Some(HivePartitions { stats }) diff --git a/crates/polars-plan/src/logical_plan/iterator.rs b/crates/polars-plan/src/logical_plan/iterator.rs index e83952476d82..611f08badd83 100644 --- a/crates/polars-plan/src/logical_plan/iterator.rs +++ b/crates/polars-plan/src/logical_plan/iterator.rs @@ -1,4 +1,6 @@ use arrow::legacy::error::PolarsResult; +use polars_utils::idx_vec::UnitVec; +use polars_utils::unitvec; use crate::prelude::*; @@ -99,14 +101,13 @@ macro_rules! push_expr { impl Expr { /// Expr::mutate().apply(fn()) pub fn mutate(&mut self) -> ExprMut { - let mut stack = Vec::with_capacity(4); - stack.push(self); + let stack = unitvec!(self); ExprMut { stack } } } pub struct ExprMut<'a> { - stack: Vec<&'a mut Expr>, + stack: UnitVec<&'a mut Expr>, } impl<'a> ExprMut<'a> { @@ -139,7 +140,7 @@ impl<'a> ExprMut<'a> { } pub struct ExprIter<'a> { - stack: Vec<&'a Expr>, + stack: UnitVec<&'a Expr>, } impl<'a> Iterator for ExprIter<'a> { @@ -154,12 +155,12 @@ impl<'a> Iterator for ExprIter<'a> { } impl Expr { - pub fn nodes<'a>(&'a self, container: &mut Vec<&'a Expr>) { + pub fn nodes<'a>(&'a self, container: &mut UnitVec<&'a Expr>) { let mut push = |e: &'a Expr| container.push(e); push_expr!(self, push, iter); } - pub fn nodes_mut<'a>(&'a mut self, container: &mut Vec<&'a mut Expr>) { + pub fn nodes_mut<'a>(&'a mut self, container: &mut UnitVec<&'a mut Expr>) { let mut push = |e: &'a mut Expr| container.push(e); push_expr!(self, push, iter_mut); } @@ -170,14 +171,13 @@ impl<'a> IntoIterator for &'a Expr { type IntoIter = ExprIter<'a>; fn into_iter(self) -> Self::IntoIter { - let mut stack = Vec::with_capacity(4); - stack.push(self); + let stack = unitvec!(self); ExprIter { stack } } } pub struct AExprIter<'a> { - stack: Vec, + stack: UnitVec, arena: Option<&'a Arena>, } @@ -203,8 +203,7 @@ pub trait ArenaExprIter<'a> { impl<'a> ArenaExprIter<'a> for &'a Arena { fn iter(&self, root: Node) -> AExprIter<'a> { - let mut stack = Vec::with_capacity(4); - stack.push(root); + let stack = unitvec![root]; AExprIter { stack, arena: Some(self), @@ -213,7 +212,7 @@ impl<'a> ArenaExprIter<'a> for &'a Arena { } pub struct AlpIter<'a> { - stack: Vec, + stack: UnitVec, arena: &'a Arena, } @@ -223,7 +222,7 @@ pub trait ArenaLpIter<'a> { impl<'a> ArenaLpIter<'a> for &'a Arena { fn iter(&self, root: Node) -> AlpIter<'a> { - let stack = vec![root]; + let stack = unitvec![root]; AlpIter { stack, arena: self } } } diff --git a/crates/polars-plan/src/logical_plan/lit.rs b/crates/polars-plan/src/logical_plan/lit.rs index 057c3f36539c..4965cd2c7d99 100644 --- a/crates/polars-plan/src/logical_plan/lit.rs +++ b/crates/polars-plan/src/logical_plan/lit.rs @@ -71,7 +71,7 @@ impl LiteralValue { } } - pub fn to_anyvalue(&self) -> Option { + pub fn to_any_value(&self) -> Option { use LiteralValue::*; let av = match self { Null => AnyValue::Null, @@ -336,7 +336,7 @@ impl Hash for LiteralValue { data_type.hash(state) }, _ => { - if let Some(v) = self.to_anyvalue() { + if let Some(v) = self.to_any_value() { v.hash_impl(state, true) } }, diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index 08e9816b900a..f378770456ad 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -34,7 +34,6 @@ mod projection_expr; #[cfg(feature = "python")] mod pyarrow; mod schema; -#[cfg(any(feature = "meta", feature = "cse"))] pub(crate) mod tree_format; pub mod visitor; @@ -55,6 +54,7 @@ pub use schema::*; use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; +use self::tree_format::{TreeFmtNode, TreeFmtVisitor}; #[cfg(any( feature = "ipc", feature = "parquet", @@ -75,72 +75,64 @@ pub enum Context { } #[derive(Debug)] -pub enum ErrorState { - NotYetEncountered { err: PolarsError }, - AlreadyEncountered { prev_err_msg: String }, -} - -impl std::fmt::Display for ErrorState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ErrorState::NotYetEncountered { err } => write!(f, "NotYetEncountered({err})")?, - ErrorState::AlreadyEncountered { prev_err_msg } => { - write!(f, "AlreadyEncountered({prev_err_msg})")? - }, - }; - - Ok(()) - } +pub(crate) struct ErrorStateUnsync { + n_times: usize, + err: PolarsError, } #[derive(Clone)] -pub struct ErrorStateSync(Arc>); - -impl std::ops::Deref for ErrorStateSync { - type Target = Arc>; +pub struct ErrorState(pub(crate) Arc>); - fn deref(&self) -> &Self::Target { - &self.0 +impl std::fmt::Debug for ErrorState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let this = self.0.lock().unwrap(); + // Skip over the Arc> and just print the fields we care + // about. Technically this is misleading, but the insides of ErrorState are not + // public, so this only affects authors of polars, not users (and the odds that + // this affects authors is slim) + f.debug_struct("ErrorState") + .field("n_times", &this.n_times) + .field("err", &this.err) + .finish() } } -impl std::fmt::Debug for ErrorStateSync { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "ErrorStateSync({})", &*self.0.lock().unwrap()) +impl From for ErrorState { + fn from(err: PolarsError) -> Self { + Self(Arc::new(Mutex::new(ErrorStateUnsync { n_times: 0, err }))) } } -impl ErrorStateSync { +impl ErrorState { fn take(&self) -> PolarsError { - let mut curr_err = self.0.lock().unwrap(); + let mut this = self.0.lock().unwrap(); - match &*curr_err { - ErrorState::NotYetEncountered { err: polars_err } => { - // Need to finish using `polars_err` here so that NLL considers `err` dropped - let prev_err_msg = polars_err.to_string(); - // Place AlreadyEncountered in `self` for future users of `self` - let prev_err = std::mem::replace( - &mut *curr_err, - ErrorState::AlreadyEncountered { prev_err_msg }, - ); - // Since we're in this branch, we know err was a NotYetEncountered - match prev_err { - ErrorState::NotYetEncountered { err } => err, - ErrorState::AlreadyEncountered { .. } => unreachable!(), - } - }, - ErrorState::AlreadyEncountered { prev_err_msg } => { - polars_err!( - ComputeError: "LogicalPlan already failed with error: '{}'", prev_err_msg, + let ret_err = if this.n_times == 0 { + this.err.wrap_msg(&|msg| msg.to_owned()) + } else { + this.err.wrap_msg(&|msg| { + let n_times = this.n_times; + + let plural_s; + let was_were; + + if n_times == 1 { + plural_s = ""; + was_were = "was" + } else { + plural_s = "s"; + was_were = "were"; + }; + format!( + "{msg}\n\nLogicalPlan had already failed with the above error; \ + after failure, {n_times} additional operation{plural_s} \ + {was_were} attempted on the LazyFrame", ) - }, - } - } -} + }) + }; + this.n_times += 1; -impl From for ErrorStateSync { - fn from(err: PolarsError) -> Self { - Self(Arc::new(Mutex::new(ErrorState::NotYetEncountered { err }))) + ret_err } } @@ -248,7 +240,7 @@ pub enum LogicalPlan { #[cfg_attr(feature = "serde", serde(skip))] Error { input: Box, - err: ErrorStateSync, + err: ErrorState, }, /// This allows expressions to access other tables ExtContext { @@ -281,6 +273,12 @@ impl LogicalPlan { format!("{self:#?}") } + pub fn describe_tree_format(&self) -> String { + let mut visitor = TreeFmtVisitor::default(); + TreeFmtNode::root_logical_plan(self).traverse(&mut visitor); + format!("{visitor:#?}") + } + pub fn to_alp(self) -> PolarsResult<(Node, Arena, Arena)> { let mut lp_arena = Arena::with_capacity(16); let mut expr_arena = Arena::with_capacity(16); diff --git a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs index 23881fdd4b3c..a7be49238fe9 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs @@ -1,5 +1,4 @@ use std::collections::BTreeMap; -use std::sync::Arc; use super::*; diff --git a/crates/polars-plan/src/logical_plan/optimizer/count_star.rs b/crates/polars-plan/src/logical_plan/optimizer/count_star.rs new file mode 100644 index 000000000000..fcbe1e61d762 --- /dev/null +++ b/crates/polars-plan/src/logical_plan/optimizer/count_star.rs @@ -0,0 +1,93 @@ +use std::path::PathBuf; + +use super::*; + +pub(super) struct CountStar; + +impl CountStar { + pub(super) fn new() -> Self { + Self + } +} + +impl OptimizationRule for CountStar { + // Replace select count(*) from datasource with specialized map function. + fn optimize_plan( + &mut self, + lp_arena: &mut Arena, + _expr_arena: &mut Arena, + node: Node, + ) -> Option { + let mut paths = Vec::new(); + visit_logical_plan_for_scan_paths(&mut paths, node, lp_arena).map(|scan_type| { + // MapFunction needs a leaf node, hence we create a dummy placeholder node + let placeholder = ALogicalPlan::DataFrameScan { + df: Arc::new(Default::default()), + schema: Arc::new(Default::default()), + output_schema: None, + projection: None, + selection: None, + }; + let placeholder_node = lp_arena.add(placeholder); + + let sliced_paths: Arc<[PathBuf]> = paths.into(); + + let alp = ALogicalPlan::MapFunction { + input: placeholder_node, + function: FunctionNode::Count { + paths: sliced_paths, + scan_type, + }, + }; + lp_arena.replace(node, alp.clone()); + alp + }) + } +} + +// Visit the logical plan and return the file paths / scan type +// Return None if query is not a simple COUNT(*) FROM SOURCE +fn visit_logical_plan_for_scan_paths( + all_paths: &mut Vec, + node: Node, + lp_arena: &Arena, +) -> Option { + match lp_arena.get(node) { + ALogicalPlan::Union { inputs, .. } => { + // Preallocate right amount in case of globbing + if all_paths.is_empty() { + let _ = std::mem::replace(all_paths, Vec::with_capacity(inputs.len())); + } + let mut scan_type = None; + for input in inputs { + match visit_logical_plan_for_scan_paths(all_paths, *input, lp_arena) { + Some(leaf_scan_type) => { + match &scan_type { + None => scan_type = Some(leaf_scan_type), + Some(scan_type) => { + // All scans must be of the same type (e.g. csv / parquet) + if std::mem::discriminant(scan_type) + != std::mem::discriminant(&leaf_scan_type) + { + return None; + } + }, + }; + }, + None => return None, + } + } + scan_type + }, + ALogicalPlan::Scan { + scan_type, paths, .. + } if !matches!(scan_type, FileScan::Anonymous { .. }) => { + all_paths.extend(paths.iter().cloned()); + Some(scan_type.clone()) + }, + ALogicalPlan::Projection { input, .. } => { + visit_logical_plan_for_scan_paths(all_paths, *input, lp_arena) + }, + _ => None, + } +} diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse.rs b/crates/polars-plan/src/logical_plan/optimizer/cse.rs index 16dc208101dd..da3ac4ada8ab 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse.rs @@ -3,6 +3,8 @@ use std::collections::{BTreeMap, BTreeSet}; use polars_core::prelude::*; +use polars_utils::idx_vec::UnitVec; +use polars_utils::unitvec; use crate::prelude::*; @@ -73,9 +75,9 @@ pub(super) fn collect_trails( }, lp => { // other nodes have only a single input - let nodes = &mut [None]; - lp.copy_inputs(nodes); - if let Some(input) = nodes[0] { + let mut nodes: UnitVec = unitvec![]; + lp.copy_inputs(&mut nodes); + if let Some(input) = nodes.pop() { collect_trails(input, lp_arena, trails, id, collect)? } }, diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs index 3e7891bc1c66..d0cd86b2a966 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs @@ -3,8 +3,7 @@ use polars_utils::vec::CapacityByFactor; use super::*; use crate::constants::CSE_REPLACED; use crate::logical_plan::projection_expr::ProjectionExprs; -use crate::logical_plan::visitor::{RewriteRecursion, VisitRecursion}; -use crate::prelude::visitor::{ALogicalPlanNode, AexprNode, RewritingVisitor, TreeWalker, Visitor}; +use crate::prelude::visitor::AexprNode; // We use hashes to get an Identifier // but this is very hard to debug, so we also have a version that @@ -361,6 +360,9 @@ impl Visitor for ExprIdentifierVisitor<'_> { fn pre_visit(&mut self, node: &Self::Node) -> PolarsResult { if skip_pre_visit(node.to_aexpr(), self.is_group_by) { + // Still add to the stack so that a parent becomes invalidated. + self.visit_stack + .push(VisitRecord::SubExprId(Identifier::new(), false)); return Ok(VisitRecursion::Skip); } diff --git a/crates/polars-plan/src/logical_plan/optimizer/delay_rechunk.rs b/crates/polars-plan/src/logical_plan/optimizer/delay_rechunk.rs index 6c3705ca93ee..7fa81a4085e9 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/delay_rechunk.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/delay_rechunk.rs @@ -1,7 +1,5 @@ use std::collections::BTreeSet; -use polars_utils::arena::{Arena, Node}; - use super::*; #[derive(Default)] diff --git a/crates/polars-plan/src/logical_plan/optimizer/drop_nulls.rs b/crates/polars-plan/src/logical_plan/optimizer/drop_nulls.rs index 8683d2716a03..937750f94024 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/drop_nulls.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/drop_nulls.rs @@ -1,10 +1,5 @@ -use std::sync::Arc; - use super::*; -use crate::dsl::function_expr::FunctionExpr; -use crate::logical_plan::functions::FunctionNode; use crate::logical_plan::iterator::*; -use crate::utils::aexpr_to_leaf_names; /// If we realize that a predicate drops nulls on a subset /// we replace it with an explicit df.drop_nulls call, as this diff --git a/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs b/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs index 14ba5f29ebbc..6f8e7f27a453 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs @@ -4,8 +4,6 @@ use polars_core::prelude::*; use smartstring::SmartString; use super::*; -use crate::logical_plan::alp::ALogicalPlan; -use crate::logical_plan::functions::FunctionNode; /// Projection in the physical plan is done by selecting an expression per thread. /// In case of many projections and columns this can be expensive when the expressions are simple diff --git a/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs b/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs index 23791d3dd6b0..d3377a145229 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs @@ -1,10 +1,7 @@ use std::path::PathBuf; -use std::sync::Arc; -use polars_core::datatypes::PlHashMap; use polars_core::prelude::*; -use crate::logical_plan::ALogicalPlanBuilder; use crate::prelude::*; #[derive(Hash, Eq, PartialEq, Clone, Debug)] diff --git a/crates/polars-plan/src/logical_plan/optimizer/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/mod.rs index 44448a61c334..226e1636c6da 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/mod.rs @@ -1,4 +1,3 @@ -use polars_core::datatypes::PlHashMap; use polars_core::prelude::*; use crate::prelude::*; @@ -10,6 +9,7 @@ mod delay_rechunk; mod drop_nulls; mod collect_members; +mod count_star; #[cfg(feature = "cse")] mod cse_expr; mod fast_projection; @@ -36,8 +36,6 @@ mod type_coercion; use delay_rechunk::DelayRechunk; use drop_nulls::ReplaceDropNulls; use fast_projection::FastProjectionAndCollapse; -#[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] -use file_caching::{find_column_union_and_fingerprints, FileCacher}; use polars_io::predicates::PhysicalIoExpr; pub use predicate_pushdown::PredicatePushDown; pub use projection_pushdown::ProjectionPushDown; @@ -48,6 +46,7 @@ pub use type_coercion::TypeCoercionRule; use self::flatten_union::FlattenUnionRule; pub use crate::frame::{AllowedOptimizations, OptState}; +use crate::logical_plan::optimizer::count_star::CountStar; #[cfg(feature = "cse")] use crate::logical_plan::optimizer::cse_expr::CommonSubExprOptimizer; use crate::logical_plan::optimizer::predicate_pushdown::HiveEval; @@ -141,6 +140,11 @@ pub fn optimize( if members.has_joins_or_unions && members.has_cache { cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, cse_plan_changed); } + + if projection_pushdown_opt.is_count_star { + let mut count_star_opt = CountStar::new(); + count_star_opt.optimize_plan(lp_arena, expr_arena, lp_top); + } } if predicate_pushdown { diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index c9f08519ffb4..b3c83ca70ae6 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -11,7 +11,6 @@ use utils::*; use super::*; use crate::dsl::function_expr::FunctionExpr; -use crate::logical_plan::optimizer; use crate::prelude::optimizer::predicate_pushdown::group_by::process_group_by; use crate::prelude::optimizer::predicate_pushdown::join::process_join; use crate::prelude::optimizer::predicate_pushdown::rename::process_rename; @@ -404,24 +403,15 @@ impl<'a> PredicatePushDown<'a> { input, options } => { - - if matches!(options.keep_strategy, UniqueKeepStrategy::Any | UniqueKeepStrategy::None) { - // currently the distinct operation only keeps the first occurrences. - // this may have influence on the pushed down predicates. If the pushed down predicates - // contain a binary expression (thus depending on values in multiple columns) - // the final result may differ if it is pushed down. - - let mut root_count = 0; - - // if this condition is called more than once, its a binary or ternary operation. - let condition = |_| { - if root_count == 0 { - root_count += 1; - false - } else { - true - } + if let Some(ref subset) = options.subset { + // Predicates on the subset can pass. + let subset = subset.clone(); + let mut names_set = PlHashSet::<&str>::with_capacity(subset.len()); + for name in subset.iter() { + names_set.insert(name.as_str()); }; + + let condition = |name: Arc| !names_set.contains(name.as_ref()); let local_predicates = transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition); diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs index 3e734f115d76..49885b8e0b61 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs @@ -1,10 +1,7 @@ -use polars_core::datatypes::PlHashMap; use polars_core::prelude::*; use super::keys::*; -use crate::logical_plan::Context; use crate::prelude::*; -use crate::utils::{aexpr_to_leaf_names, has_aexpr}; trait Dsl { fn and(self, right: Node, arena: &mut Arena) -> Node; @@ -115,8 +112,9 @@ pub(super) fn predicate_is_sort_boundary(node: Node, expr_arena: &Arena) has_aexpr(node, expr_arena, matches) } -/// Transfer a predicate from `acc_predicates` that will be pushed down -/// to a local_predicates vec based on a condition. +/// Evaluates a condition on the column name inputs of every predicate, where if +/// the condition evaluates to true on any column name the predicate is +/// transferred to local. pub(super) fn transfer_to_local_by_name( expr_arena: &Arena, acc_predicates: &mut PlHashMap, Node>, @@ -132,7 +130,7 @@ where for name in root_names { if condition(name) { remove_keys.push(key.clone()); - continue; + break; } } } diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs index 519d521003c2..e10a13eb261f 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs @@ -16,7 +16,6 @@ use polars_io::RowIndex; use semi_anti_join::process_semi_anti_join; use crate::logical_plan::Context; -use crate::prelude::iterator::ArenaExprIter; use crate::prelude::optimizer::projection_pushdown::generic::process_generic; use crate::prelude::optimizer::projection_pushdown::group_by::process_group_by; use crate::prelude::optimizer::projection_pushdown::hconcat::process_hconcat; @@ -26,8 +25,7 @@ use crate::prelude::optimizer::projection_pushdown::projection::process_projecti use crate::prelude::optimizer::projection_pushdown::rename::process_rename; use crate::prelude::*; use crate::utils::{ - aexpr_assign_renamed_leaf, aexpr_to_column_nodes, aexpr_to_leaf_names, check_input_node, - expr_is_projected_upstream, + aexpr_assign_renamed_leaf, aexpr_to_leaf_names, check_input_node, expr_is_projected_upstream, }; fn init_vec() -> Vec { @@ -134,9 +132,9 @@ fn update_scan_schema( let mut new_schema = Schema::with_capacity(acc_projections.len()); let mut new_cols = Vec::with_capacity(acc_projections.len()); for node in acc_projections.iter() { - for name in aexpr_to_leaf_names(*node, expr_arena) { + for name in aexpr_to_leaf_names_iter(*node, expr_arena) { let item = schema.get_full(&name).ok_or_else(|| { - polars_err!(ComputeError: "column '{}' not available in schema {:?}", name, schema) + polars_err!(ComputeError: "column '{}' not available in 'DataFrame' with {:?}", name, schema) })?; new_cols.push(item); } @@ -151,11 +149,15 @@ fn update_scan_schema( Ok(new_schema) } -pub struct ProjectionPushDown {} +pub struct ProjectionPushDown { + pub is_count_star: bool, +} impl ProjectionPushDown { pub(super) fn new() -> Self { - Self {} + Self { + is_count_star: false, + } } /// Projection will be done at this node, but we continue optimization diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs index 8e714a3a40ca..04fb96b11c1b 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs @@ -23,13 +23,33 @@ fn check_double_projection( acc_projections: &mut Vec, projected_names: &mut PlHashSet>, ) { + // Factor out the pruning function + fn prune_projections_by_name( + acc_projections: &mut Vec, + name: &str, + expr_arena: &Arena, + ) { + acc_projections.retain(|expr| { + !aexpr_to_leaf_names_iter(*expr, expr_arena).any(|q| q.as_ref() == name) + }); + } + for (_, ae) in (&*expr_arena).iter(*expr) { - if let AExpr::Alias(_, name) = ae { - if projected_names.remove(name) { - acc_projections - .retain(|expr| !aexpr_to_leaf_names(*expr, expr_arena).contains(name)); - } - } + match ae { + // Series literals come from another source so should not be pushed down. + AExpr::Literal(LiteralValue::Series(s)) => { + let name = s.name(); + if projected_names.remove(name) { + prune_projections_by_name(acc_projections, name, expr_arena) + } + }, + AExpr::Alias(_, name) => { + if projected_names.remove(name) { + prune_projections_by_name(acc_projections, name.as_ref(), expr_arena) + } + }, + _ => {}, + }; } } @@ -64,6 +84,7 @@ pub(super) fn process_projection( } add_expr_to_accumulated(expr, &mut acc_projections, &mut projected_names, expr_arena); local_projection.push(exprs[0]); + proj_pd.is_count_star = true; } else { // A projection can consist of a chain of expressions followed by an alias. // We want to do the chain locally because it can have complicated side effects. diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs index c67603b8e106..1e3f4aa533f1 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs @@ -1,67 +1,64 @@ -use polars_utils::arena::Arena; +use polars_utils::floor_divmod::FloorDivMod; +use polars_utils::total_ord::ToTotalOrd; -#[cfg(all(feature = "strings", feature = "concat_str"))] -use crate::dsl::function_expr::StringFunction; -use crate::logical_plan::optimizer::stack_opt::OptimizationRule; use crate::logical_plan::*; use crate::prelude::optimizer::simplify_functions::optimize_functions; macro_rules! eval_binary_same_type { - ($lhs:expr, $operand: tt, $rhs:expr) => {{ - if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) { - match (lit_left, lit_right) { - (LiteralValue::Float32(x), LiteralValue::Float32(y)) => { - Some(AExpr::Literal(LiteralValue::Float32(x $operand y))) - } - (LiteralValue::Float64(x), LiteralValue::Float64(y)) => { - Some(AExpr::Literal(LiteralValue::Float64(x $operand y))) - } - #[cfg(feature = "dtype-i8")] - (LiteralValue::Int8(x), LiteralValue::Int8(y)) => { - Some(AExpr::Literal(LiteralValue::Int8(x $operand y))) - } - #[cfg(feature = "dtype-i16")] - (LiteralValue::Int16(x), LiteralValue::Int16(y)) => { - Some(AExpr::Literal(LiteralValue::Int16(x $operand y))) - } - (LiteralValue::Int32(x), LiteralValue::Int32(y)) => { - Some(AExpr::Literal(LiteralValue::Int32(x $operand y))) - } - (LiteralValue::Int64(x), LiteralValue::Int64(y)) => { - Some(AExpr::Literal(LiteralValue::Int64(x $operand y))) - } - #[cfg(feature = "dtype-u8")] - (LiteralValue::UInt8(x), LiteralValue::UInt8(y)) => { - Some(AExpr::Literal(LiteralValue::UInt8(x $operand y))) - } - #[cfg(feature = "dtype-u16")] - (LiteralValue::UInt16(x), LiteralValue::UInt16(y)) => { - Some(AExpr::Literal(LiteralValue::UInt16(x $operand y))) - } - (LiteralValue::UInt32(x), LiteralValue::UInt32(y)) => { - Some(AExpr::Literal(LiteralValue::UInt32(x $operand y))) - } - (LiteralValue::UInt64(x), LiteralValue::UInt64(y)) => { - Some(AExpr::Literal(LiteralValue::UInt64(x $operand y))) + ($lhs:expr, $rhs:expr, |$l: ident, $r: ident| $ret: expr) => {{ + if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) { + match (lit_left, lit_right) { + (LiteralValue::Float32($l), LiteralValue::Float32($r)) => { + Some(AExpr::Literal(LiteralValue::Float32($ret))) + }, + (LiteralValue::Float64($l), LiteralValue::Float64($r)) => { + Some(AExpr::Literal(LiteralValue::Float64($ret))) + }, + #[cfg(feature = "dtype-i8")] + (LiteralValue::Int8($l), LiteralValue::Int8($r)) => { + Some(AExpr::Literal(LiteralValue::Int8($ret))) + }, + #[cfg(feature = "dtype-i16")] + (LiteralValue::Int16($l), LiteralValue::Int16($r)) => { + Some(AExpr::Literal(LiteralValue::Int16($ret))) + }, + (LiteralValue::Int32($l), LiteralValue::Int32($r)) => { + Some(AExpr::Literal(LiteralValue::Int32($ret))) + }, + (LiteralValue::Int64($l), LiteralValue::Int64($r)) => { + Some(AExpr::Literal(LiteralValue::Int64($ret))) + }, + #[cfg(feature = "dtype-u8")] + (LiteralValue::UInt8($l), LiteralValue::UInt8($r)) => { + Some(AExpr::Literal(LiteralValue::UInt8($ret))) + }, + #[cfg(feature = "dtype-u16")] + (LiteralValue::UInt16($l), LiteralValue::UInt16($r)) => { + Some(AExpr::Literal(LiteralValue::UInt16($ret))) + }, + (LiteralValue::UInt32($l), LiteralValue::UInt32($r)) => { + Some(AExpr::Literal(LiteralValue::UInt32($ret))) + }, + (LiteralValue::UInt64($l), LiteralValue::UInt64($r)) => { + Some(AExpr::Literal(LiteralValue::UInt64($ret))) + }, + _ => None, } - _ => None, + } else { + None } - } else { - None - } - - }} + }}; } -macro_rules! eval_binary_bool_type { +macro_rules! eval_binary_cmp_same_type { ($lhs:expr, $operand: tt, $rhs:expr) => {{ if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) { match (lit_left, lit_right) { (LiteralValue::Float32(x), LiteralValue::Float32(y)) => { - Some(AExpr::Literal(LiteralValue::Boolean(x $operand y))) + Some(AExpr::Literal(LiteralValue::Boolean(x.to_total_ord() $operand y.to_total_ord()))) } (LiteralValue::Float64(x), LiteralValue::Float64(y)) => { - Some(AExpr::Literal(LiteralValue::Boolean(x $operand y))) + Some(AExpr::Literal(LiteralValue::Boolean(x.to_total_ord() $operand y.to_total_ord()))) } #[cfg(feature = "dtype-i8")] (LiteralValue::Int8(x), LiteralValue::Int8(y)) => { @@ -229,12 +226,49 @@ impl OptimizationRule for SimplifyBooleanRule { { Some(AExpr::Literal(LiteralValue::Boolean(true))) }, + AExpr::Function { + input, + function: FunctionExpr::Negate, + .. + } if input.len() == 1 => { + let input = input[0]; + let ae = expr_arena.get(input); + eval_negate(ae) + }, + // Flatten Aliases. + AExpr::Alias(inner, name) => { + let input = expr_arena.get(*inner); + + if let AExpr::Alias(input, _) = input { + Some(AExpr::Alias(*input, name.clone())) + } else { + None + } + }, _ => None, }; Ok(out) } } +fn eval_negate(ae: &AExpr) -> Option { + let out = match ae { + AExpr::Literal(lv) => match lv { + #[cfg(feature = "dtype-i8")] + LiteralValue::Int8(v) => LiteralValue::Int8(-*v), + #[cfg(feature = "dtype-i16")] + LiteralValue::Int16(v) => LiteralValue::Int16(-*v), + LiteralValue::Int32(v) => LiteralValue::Int32(-*v), + LiteralValue::Int64(v) => LiteralValue::Int64(-*v), + LiteralValue::Float32(v) => LiteralValue::Float32(-*v), + LiteralValue::Float64(v) => LiteralValue::Float64(-*v), + _ => return None, + }, + _ => return None, + }; + Some(AExpr::Literal(out)) +} + fn eval_bitwise(left: &AExpr, right: &AExpr, operation: F) -> Option where F: Fn(bool, bool) -> bool, @@ -396,9 +430,9 @@ impl OptimizationRule for SimplifyExprRule { _lp_arena: &Arena, _lp_node: Node, ) -> PolarsResult> { - let expr = expr_arena.get(expr_node); + let expr = expr_arena.get(expr_node).clone(); - let out = match expr { + let out = match &expr { // lit(left) + lit(right) => lit(left + right) // and null propagation AExpr::BinaryExpr { left, op, right } => { @@ -410,7 +444,7 @@ impl OptimizationRule for SimplifyExprRule { #[allow(clippy::manual_map)] let out = match op { Plus => { - match eval_binary_same_type!(left_aexpr, +, right_aexpr) { + match eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + r) { Some(new) => Some(new), None => { // try to replace addition of string columns with `concat_str` @@ -433,9 +467,61 @@ impl OptimizationRule for SimplifyExprRule { }, } }, - Minus => eval_binary_same_type!(left_aexpr, -, right_aexpr), - Multiply => eval_binary_same_type!(left_aexpr, *, right_aexpr), - Divide => eval_binary_same_type!(left_aexpr, /, right_aexpr), + Minus => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l - r), + Multiply => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l * r), + Divide => { + if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = + (left_aexpr, right_aexpr) + { + match (lit_left, lit_right) { + (LiteralValue::Float32(x), LiteralValue::Float32(y)) => { + Some(AExpr::Literal(LiteralValue::Float32(x / y))) + }, + (LiteralValue::Float64(x), LiteralValue::Float64(y)) => { + Some(AExpr::Literal(LiteralValue::Float64(x / y))) + }, + #[cfg(feature = "dtype-i8")] + (LiteralValue::Int8(x), LiteralValue::Int8(y)) => { + Some(AExpr::Literal(LiteralValue::Int8( + x.wrapping_floor_div_mod(*y).0, + ))) + }, + #[cfg(feature = "dtype-i16")] + (LiteralValue::Int16(x), LiteralValue::Int16(y)) => { + Some(AExpr::Literal(LiteralValue::Int16( + x.wrapping_floor_div_mod(*y).0, + ))) + }, + (LiteralValue::Int32(x), LiteralValue::Int32(y)) => { + Some(AExpr::Literal(LiteralValue::Int32( + x.wrapping_floor_div_mod(*y).0, + ))) + }, + (LiteralValue::Int64(x), LiteralValue::Int64(y)) => { + Some(AExpr::Literal(LiteralValue::Int64( + x.wrapping_floor_div_mod(*y).0, + ))) + }, + #[cfg(feature = "dtype-u8")] + (LiteralValue::UInt8(x), LiteralValue::UInt8(y)) => { + Some(AExpr::Literal(LiteralValue::UInt8(x / y))) + }, + #[cfg(feature = "dtype-u16")] + (LiteralValue::UInt16(x), LiteralValue::UInt16(y)) => { + Some(AExpr::Literal(LiteralValue::UInt16(x / y))) + }, + (LiteralValue::UInt32(x), LiteralValue::UInt32(y)) => { + Some(AExpr::Literal(LiteralValue::UInt32(x / y))) + }, + (LiteralValue::UInt64(x), LiteralValue::UInt64(y)) => { + Some(AExpr::Literal(LiteralValue::UInt64(x / y))) + }, + _ => None, + } + } else { + None + } + }, TrueDivide => { if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = (left_aexpr, right_aexpr) @@ -481,17 +567,23 @@ impl OptimizationRule for SimplifyExprRule { None } }, - Modulus => eval_binary_same_type!(left_aexpr, %, right_aexpr), - Lt => eval_binary_bool_type!(left_aexpr, <, right_aexpr), - Gt => eval_binary_bool_type!(left_aexpr, >, right_aexpr), - Eq | EqValidity => eval_binary_bool_type!(left_aexpr, ==, right_aexpr), - NotEq | NotEqValidity => eval_binary_bool_type!(left_aexpr, !=, right_aexpr), - GtEq => eval_binary_bool_type!(left_aexpr, >=, right_aexpr), - LtEq => eval_binary_bool_type!(left_aexpr, <=, right_aexpr), + Modulus => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + .wrapping_floor_div_mod(*r) + .1), + Lt => eval_binary_cmp_same_type!(left_aexpr, <, right_aexpr), + Gt => eval_binary_cmp_same_type!(left_aexpr, >, right_aexpr), + Eq | EqValidity => eval_binary_cmp_same_type!(left_aexpr, ==, right_aexpr), + NotEq | NotEqValidity => { + eval_binary_cmp_same_type!(left_aexpr, !=, right_aexpr) + }, + GtEq => eval_binary_cmp_same_type!(left_aexpr, >=, right_aexpr), + LtEq => eval_binary_cmp_same_type!(left_aexpr, <=, right_aexpr), And | LogicalAnd => eval_bitwise(left_aexpr, right_aexpr, |l, r| l & r), Or | LogicalOr => eval_bitwise(left_aexpr, right_aexpr, |l, r| l | r), Xor => eval_bitwise(left_aexpr, right_aexpr, |l, r| l ^ r), - FloorDivide => None, + FloorDivide => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + .wrapping_floor_div_mod(*r) + .0), }; if out.is_some() { return Ok(out); @@ -534,7 +626,7 @@ fn inline_cast(input: &AExpr, dtype: &DataType, strict: bool) -> PolarsResult { - let Some(av) = lv.to_anyvalue() else { + let Some(av) = lv.to_any_value() else { return Ok(None); }; if dtype == &av.dtype() { diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs index 5ec3655effa1..70d7d57e413a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs @@ -3,8 +3,8 @@ use super::*; pub(super) fn optimize_functions( input: &[Node], function: &FunctionExpr, - _options: &FunctionOptions, - expr_arena: &Arena, + options: &FunctionOptions, + expr_arena: &mut Arena, ) -> PolarsResult> { let out = match function { // sort().reverse() -> sort(reverse) @@ -54,7 +54,7 @@ pub(super) fn optimize_functions( Some(AExpr::Function { input: new_inputs, function: function.clone(), - options: *_options, + options: *options, }) } else { None @@ -62,15 +62,55 @@ pub(super) fn optimize_functions( }, FunctionExpr::Boolean(BooleanFunction::AllHorizontal | BooleanFunction::AnyHorizontal) => { if input.len() == 1 { - Some(expr_arena.get(input[0]).clone()) + Some(AExpr::Cast { + expr: input[0], + data_type: DataType::Boolean, + strict: false, + }) } else { None } }, FunctionExpr::Boolean(BooleanFunction::Not) => { - let y = expr_arena.get(input[0]); + let y = expr_arena.get(input[0]).clone(); match y { + // not(a and b) => not(a) or not(b) + AExpr::BinaryExpr { + left, + op: Operator::And | Operator::LogicalAnd, + right, + } => Some(AExpr::BinaryExpr { + left: expr_arena.add(AExpr::Function { + input: vec![left], + function: FunctionExpr::Boolean(BooleanFunction::Not), + options: *options, + }), + op: Operator::Or, + right: expr_arena.add(AExpr::Function { + input: vec![right], + function: FunctionExpr::Boolean(BooleanFunction::Not), + options: *options, + }), + }), + // not(a or b) => not(a) and not(b) + AExpr::BinaryExpr { + left, + op: Operator::Or | Operator::LogicalOr, + right, + } => Some(AExpr::BinaryExpr { + left: expr_arena.add(AExpr::Function { + input: vec![left], + function: FunctionExpr::Boolean(BooleanFunction::Not), + options: *options, + }), + op: Operator::And, + right: expr_arena.add(AExpr::Function { + input: vec![right], + function: FunctionExpr::Boolean(BooleanFunction::Not), + options: *options, + }), + }), // not(not x) => x AExpr::Function { input, @@ -81,6 +121,123 @@ pub(super) fn optimize_functions( AExpr::Literal(LiteralValue::Boolean(b)) => { Some(AExpr::Literal(LiteralValue::Boolean(!b))) }, + // not(x.is_null) => x.is_not_null + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options, + } => Some(AExpr::Function { + input: input.clone(), + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options, + }), + // not(x.is_not_null) => x.is_null + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options, + } => Some(AExpr::Function { + input: input.clone(), + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options, + }), + // not(a == b) => a != b + AExpr::BinaryExpr { + left, + op: Operator::Eq, + right, + } => Some(AExpr::BinaryExpr { + left, + op: Operator::NotEq, + right, + }), + // not(a != b) => a == b + AExpr::BinaryExpr { + left, + op: Operator::NotEq, + right, + } => Some(AExpr::BinaryExpr { + left, + op: Operator::Eq, + right, + }), + // not(a < b) => a >= b + AExpr::BinaryExpr { + left, + op: Operator::Lt, + right, + } => Some(AExpr::BinaryExpr { + left, + op: Operator::GtEq, + right, + }), + // not(a <= b) => a > b + AExpr::BinaryExpr { + left, + op: Operator::LtEq, + right, + } => Some(AExpr::BinaryExpr { + left, + op: Operator::Gt, + right, + }), + // not(a > b) => a <= b + AExpr::BinaryExpr { + left, + op: Operator::Gt, + right, + } => Some(AExpr::BinaryExpr { + left, + op: Operator::LtEq, + right, + }), + // not(a >= b) => a < b + AExpr::BinaryExpr { + left, + op: Operator::GtEq, + right, + } => Some(AExpr::BinaryExpr { + left, + op: Operator::Lt, + right, + }), + #[cfg(feature = "is_between")] + // not(col('x').is_between(a,b)) => col('x') < a || col('x') > b + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsBetween { closed }), + .. + } => { + if !matches!(expr_arena.get(input[0]), AExpr::Column(_)) { + None + } else { + let left_cmp_op = match closed { + ClosedInterval::Both | ClosedInterval::Left => Operator::Lt, + ClosedInterval::None | ClosedInterval::Right => Operator::LtEq, + }; + let right_cmp_op = match closed { + ClosedInterval::Both | ClosedInterval::Right => Operator::Gt, + ClosedInterval::None | ClosedInterval::Left => Operator::GtEq, + }; + // input[0] is between input[1] and input[2] + Some(AExpr::BinaryExpr { + // input[0] (<,<=) input[1] + left: expr_arena.add(AExpr::BinaryExpr { + left: input[0], + op: left_cmp_op, + right: input[1], + }), + // OR + op: Operator::Or, + // input[0] (>,>=) input[2] + right: expr_arena.add(AExpr::BinaryExpr { + left: input[0], + op: right_cmp_op, + right: input[2], + }), + }) + } + }, _ => None, } }, diff --git a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs index 74d55baeb068..3560ec2df01e 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs @@ -1,5 +1,6 @@ use polars_core::prelude::*; +use crate::logical_plan::projection_expr::ProjectionExprs; use crate::prelude::*; pub(super) struct SlicePushDown { @@ -13,6 +14,49 @@ struct State { len: IdxSize, } +/// Can push down slice when: +/// * all projections are elementwise +/// * at least 1 projection is based on a column (for height broadcast) +/// * projections not based on any column project as scalars +/// +/// Returns (all_elementwise, all_elementwise_and_any_expr_has_column) +fn can_pushdown_slice_past_projections( + exprs: &ProjectionExprs, + arena: &Arena, +) -> (bool, bool) { + let mut all_elementwise_and_any_expr_has_column = false; + for node in exprs.iter() { + // `select(c = Literal([1, 2, 3])).slice(0, 0)` must block slice pushdown, + // because `c` projects to a height independent from the input height. We check + // this by observing that `c` does not have any columns in its input notes. + // + // TODO: Simply checking that a column node is present does not handle e.g.: + // `select(c = Literal([1, 2, 3]).is_in(col(a)))`, for functions like `is_in`, + // `str.contains`, `str.contains_many` etc. - observe a column node is present + // but the output height is not dependent on it. + let mut has_column = false; + let mut literals_all_scalar = true; + let is_elementwise = arena.iter(*node).all(|(_node, ae)| { + has_column |= matches!(ae, AExpr::Column(_)); + literals_all_scalar &= if let AExpr::Literal(v) = ae { + v.projects_as_scalar() + } else { + true + }; + single_aexpr_is_elementwise(ae) + }); + + // If there is no column then all literals must be scalar + if !is_elementwise || !(has_column || literals_all_scalar) { + return (false, false); + } + + all_elementwise_and_any_expr_has_column |= has_column + } + + (true, all_elementwise_and_any_expr_has_column) +} + impl SlicePushDown { pub(super) fn new(streaming: bool) -> Self { Self { @@ -322,10 +366,7 @@ impl SlicePushDown { } // there is state, inspect the projection to determine how to deal with it (Projection {input, expr, schema, options}, Some(_)) => { - // The slice operation may only pass on simple projections. col("foo").alias("bar") - if expr.iter().all(|root| { - aexpr_is_elementwise(*root, expr_arena) - }) { + if can_pushdown_slice_past_projections(&expr, expr_arena).1 { let lp = Projection {input, expr, schema, options}; self.pushdown_and_continue(lp, state, lp_arena, expr_arena) } @@ -335,12 +376,16 @@ impl SlicePushDown { self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) } } - // this is copied from `Projection` (HStack {input, exprs, schema, options}, _) => { - // The slice operation may only pass on simple projections. col("foo").alias("bar") - if exprs.iter().all(|root| { - aexpr_is_elementwise(*root, expr_arena) - }) { + let check = can_pushdown_slice_past_projections(&exprs, expr_arena); + + if ( + // If the schema length is greater then an input column is being projected, so + // the exprs in with_columns do not need to have an input column name. + schema.len() > exprs.len() && check.0 + ) + || check.1 // e.g. select(c).with_columns(c = c + 1) + { let lp = HStack {input, exprs, schema, options}; self.pushdown_and_continue(lp, state, lp_arena, expr_arena) } diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs index df9c15885968..91a6d812a616 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs @@ -4,12 +4,11 @@ use std::borrow::Cow; use polars_core::prelude::*; use polars_core::utils::get_supertype; +use polars_utils::idx_vec::UnitVec; +use polars_utils::unitvec; use super::*; -use crate::dsl::function_expr::FunctionExpr; use crate::logical_plan::optimizer::type_coercion::binary::process_binary; -use crate::logical_plan::Context; -use crate::utils::is_scan; pub struct TypeCoercionRule {} @@ -196,7 +195,7 @@ fn modify_supertype( // cast literal to right type if they fit in the range (Literal(value), _) => { - if let Some(lit_val) = value.to_anyvalue() { + if let Some(lit_val) = value.to_any_value() { if type_right.value_within_range(lit_val) { st = type_right.clone(); } @@ -204,7 +203,7 @@ fn modify_supertype( }, // cast literal to left type (_, Literal(value)) => { - if let Some(lit_val) = value.to_anyvalue() { + if let Some(lit_val) = value.to_any_value() { if type_left.value_within_range(lit_val) { st = type_left.clone(); } @@ -231,7 +230,11 @@ fn modify_supertype( | (List(other), List(inner), AExpr::Literal(_), _) if inner != other => { - st = DataType::List(inner.clone()) + st = match &**inner { + #[cfg(feature = "dtype-categorical")] + Categorical(_, ordering) => List(Box::new(Categorical(None, *ordering))), + _ => List(inner.clone()), + }; }, // do nothing _ => {}, @@ -240,13 +243,13 @@ fn modify_supertype( st } -fn get_input(lp_arena: &Arena, lp_node: Node) -> [Option; 2] { +fn get_input(lp_arena: &Arena, lp_node: Node) -> UnitVec { let plan = lp_arena.get(lp_node); - let mut inputs = [None, None]; + let mut inputs: UnitVec = unitvec!(); // Used to get the schema of the input. if is_scan(plan) { - inputs[0] = Some(lp_node); + inputs.push(lp_node); } else { plan.copy_inputs(&mut inputs); }; @@ -254,10 +257,13 @@ fn get_input(lp_arena: &Arena, lp_node: Node) -> [Option; 2] } fn get_schema(lp_arena: &Arena, lp_node: Node) -> Cow<'_, SchemaRef> { - match get_input(lp_arena, lp_node) { - [Some(input), _] => lp_arena.get(input).schema(lp_arena), - // files don't have an input, so we must take their schema - [None, _] => Cow::Borrowed(lp_arena.get(lp_node).scan_schema()), + let inputs = get_input(lp_arena, lp_node); + if inputs.is_empty() { + // Files don't have an input, so we must take their schema. + Cow::Borrowed(lp_arena.get(lp_node).scan_schema()) + } else { + let input = inputs[0]; + lp_arena.get(input).schema(lp_arena) } } @@ -363,6 +369,10 @@ impl OptimizationRule for TypeCoercionRule { (DataType::Categorical(_, _) | DataType::Enum(_, _), DataType::String) => { return Ok(None) }, + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Categorical(_, _) | DataType::Enum(_, _)) => { + return Ok(None) + }, #[cfg(feature = "dtype-decimal")] (DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => { polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left) diff --git a/crates/polars-plan/src/logical_plan/options.rs b/crates/polars-plan/src/logical_plan/options.rs index d07e92eadf25..cb1325104fd3 100644 --- a/crates/polars-plan/src/logical_plan/options.rs +++ b/crates/polars-plan/src/logical_plan/options.rs @@ -39,6 +39,7 @@ pub struct CsvParserOptions { pub try_parse_dates: bool, pub raise_if_empty: bool, pub truncate_ragged_lines: bool, + pub n_threads: Option, } #[cfg(feature = "parquet")] diff --git a/crates/polars-plan/src/logical_plan/projection.rs b/crates/polars-plan/src/logical_plan/projection.rs index 9c8968b7a2f7..5d7924ed027f 100644 --- a/crates/polars-plan/src/logical_plan/projection.rs +++ b/crates/polars-plan/src/logical_plan/projection.rs @@ -1,10 +1,7 @@ //! this contains code used for rewriting projections, expanding wildcards, regex selection etc. -use arrow::legacy::index::IndexToUsize; use polars_core::utils::get_supertype; use super::*; -use crate::prelude::function_expr::FunctionExpr; -use crate::utils::expr_output_name; /// This replace the wildcard Expr with a Column Expr. It also removes the Exclude Expr from the /// expression chain. diff --git a/crates/polars-plan/src/logical_plan/pyarrow.rs b/crates/polars-plan/src/logical_plan/pyarrow.rs index 0aa7d236750a..91ba9aca3ea9 100644 --- a/crates/polars-plan/src/logical_plan/pyarrow.rs +++ b/crates/polars-plan/src/logical_plan/pyarrow.rs @@ -65,7 +65,7 @@ pub(super) fn predicate_to_pa( } }, AExpr::Literal(lv) => { - let av = lv.to_anyvalue()?; + let av = lv.to_any_value()?; let dtype = av.dtype(); match av.as_borrowed() { AnyValue::String(s) => Some(format!("'{s}'")), diff --git a/crates/polars-plan/src/logical_plan/tree_format.rs b/crates/polars-plan/src/logical_plan/tree_format.rs index 00b49fd7843f..a4d8f8eb2f02 100644 --- a/crates/polars-plan/src/logical_plan/tree_format.rs +++ b/crates/polars-plan/src/logical_plan/tree_format.rs @@ -1,6 +1,9 @@ -use std::fmt::{Display, Formatter, UpperExp}; +use std::borrow::Cow; +use std::fmt::{Debug, Display, Formatter, UpperExp}; use polars_core::error::*; +#[cfg(feature = "regex")] +use regex::Regex; use crate::logical_plan::visitor::{VisitRecursion, Visitor}; use crate::prelude::visitor::AexprNode; @@ -61,22 +64,244 @@ impl UpperExp for AExpr { } } -pub(crate) struct TreeFmtVisitor { - levels: Vec>, - depth: u32, - width: u32, +pub enum TreeFmtNode<'a> { + Expression(Option, &'a Expr), + LogicalPlan(Option, &'a LogicalPlan), } -impl TreeFmtVisitor { - pub(crate) fn new() -> Self { - Self { - levels: vec![], - depth: 0, - width: 0, +struct TreeFmtNodeData<'a>(String, Vec>); + +fn with_header(header: &Option, text: &str) -> String { + if let Some(header) = header { + format!("{header}\n{text}") + } else { + text.to_string() + } +} + +#[cfg(feature = "regex")] +fn multiline_expression(expr: &str) -> Cow<'_, str> { + let re = Regex::new(r"([\)\]])(\.[a-z0-9]+\()").unwrap(); + re.replace_all(expr, "$1\n $2") +} + +impl<'a> TreeFmtNode<'a> { + pub fn root_logical_plan(lp: &'a LogicalPlan) -> Self { + Self::LogicalPlan(None, lp) + } + + pub fn traverse(&self, visitor: &mut TreeFmtVisitor) { + let TreeFmtNodeData(title, child_nodes) = self.node_data(); + + if visitor.levels.len() <= visitor.depth { + visitor.levels.push(vec![]); + } + + let row = visitor.levels.get_mut(visitor.depth).unwrap(); + row.resize(visitor.width + 1, "".to_string()); + + row[visitor.width] = title; + visitor.prev_depth = visitor.depth; + visitor.depth += 1; + + for child in &child_nodes { + child.traverse(visitor); + } + + visitor.depth -= 1; + visitor.width += if visitor.prev_depth == visitor.depth { + 1 + } else { + 0 + }; + } + + fn node_data(&self) -> TreeFmtNodeData<'_> { + use LogicalPlan::*; + use TreeFmtNode::{Expression as NE, LogicalPlan as NL}; + use {with_header as wh, TreeFmtNodeData as ND}; + + match self { + #[cfg(feature = "regex")] + NE(h, expr) => ND(wh(h, &multiline_expression(&format!("{expr:?}"))), vec![]), + #[cfg(not(feature = "regex"))] + NE(h, expr) => ND(wh(h, &format!("{expr:?}")), vec![]), + #[cfg(feature = "python")] + NL(h, lp @ PythonScan { .. }) => ND(wh(h, &format!("{lp:?}",)), vec![]), + NL(h, lp @ Scan { .. }) => ND(wh(h, &format!("{lp:?}",)), vec![]), + NL( + h, + DataFrameScan { + schema, + projection, + selection, + .. + }, + ) => ND( + wh( + h, + &format!( + "DF {:?}\nPROJECT {}/{} COLUMNS", + schema.iter_names().take(4).collect::>(), + if let Some(columns) = projection { + format!("{}", columns.len()) + } else { + "*".to_string() + }, + schema.len() + ), + ), + if let Some(expr) = selection { + vec![NE(Some("SELECTION:".to_string()), expr)] + } else { + vec![] + }, + ), + NL(h, Union { inputs, options }) => ND( + wh( + h, + &(if let Some(slice) = options.slice { + format!("SLICED UNION: {slice:?}") + } else { + "UNION".to_string() + }), + ), + inputs + .iter() + .enumerate() + .map(|(i, lp)| NL(Some(format!("PLAN {i}:")), lp)) + .collect(), + ), + NL(h, HConcat { inputs, .. }) => ND( + wh(h, "HCONCAT"), + inputs + .iter() + .enumerate() + .map(|(i, lp)| NL(Some(format!("PLAN {i}:")), lp)) + .collect(), + ), + NL(h, Cache { input, id, count }) => ND( + wh(h, &format!("CACHE[id: {:x}, count: {}]", *id, *count)), + vec![NL(None, input)], + ), + NL(h, Selection { input, predicate }) => ND( + wh(h, "FILTER"), + vec![ + NE(Some("predicate:".to_string()), predicate), + NL(Some("FROM:".to_string()), input), + ], + ), + NL(h, Projection { expr, input, .. }) => ND( + wh(h, "SELECT"), + expr.iter() + .map(|expr| NE(Some("expression:".to_string()), expr)) + .chain([NL(Some("FROM:".to_string()), input)]) + .collect(), + ), + NL( + h, + LogicalPlan::Sort { + input, by_column, .. + }, + ) => ND( + wh(h, "SORT BY"), + by_column + .iter() + .map(|expr| NE(Some("expression:".to_string()), expr)) + .chain([NL(None, input)]) + .collect(), + ), + NL( + h, + Aggregate { + input, keys, aggs, .. + }, + ) => ND( + wh(h, "AGGREGATE"), + aggs.iter() + .map(|expr| NE(Some("expression:".to_string()), expr)) + .chain( + keys.iter() + .map(|expr| NE(Some("aggregate by:".to_string()), expr)), + ) + .chain([NL(Some("FROM:".to_string()), input)]) + .collect(), + ), + NL( + h, + Join { + input_left, + input_right, + left_on, + right_on, + options, + .. + }, + ) => ND( + wh(h, &format!("{} JOIN", options.args.how)), + left_on + .iter() + .map(|expr| NE(Some("left on:".to_string()), expr)) + .chain([NL(Some("LEFT PLAN:".to_string()), input_left)]) + .chain( + right_on + .iter() + .map(|expr| NE(Some("right on:".to_string()), expr)), + ) + .chain([NL(Some("RIGHT PLAN:".to_string()), input_right)]) + .collect(), + ), + NL(h, HStack { input, exprs, .. }) => ND( + wh(h, "WITH_COLUMNS"), + exprs + .iter() + .map(|expr| NE(Some("expression:".to_string()), expr)) + .chain([NL(None, input)]) + .collect(), + ), + NL(h, Distinct { input, options }) => ND( + wh( + h, + &format!( + "UNIQUE[maintain_order: {:?}, keep_strategy: {:?}] BY {:?}", + options.maintain_order, options.keep_strategy, options.subset + ), + ), + vec![NL(None, input)], + ), + NL(h, LogicalPlan::Slice { input, offset, len }) => ND( + wh(h, &format!("SLICE[offset: {offset}, len: {len}]")), + vec![NL(None, input)], + ), + NL(h, MapFunction { input, function }) => { + ND(wh(h, &format!("{function}")), vec![NL(None, input)]) + }, + NL(h, Error { input, err }) => ND(wh(h, &format!("{err:?}")), vec![NL(None, input)]), + NL(h, ExtContext { input, .. }) => ND(wh(h, "EXTERNAL_CONTEXT"), vec![NL(None, input)]), + NL(h, Sink { input, payload }) => ND( + wh( + h, + match payload { + SinkType::Memory => "SINK (memory)", + SinkType::File { .. } => "SINK (file)", + #[cfg(feature = "cloud")] + SinkType::Cloud { .. } => "SINK (cloud)", + }, + ), + vec![NL(None, input)], + ), } } } +#[derive(Default)] +pub(crate) struct TreeFmtVisitor { + levels: Vec>, + prev_depth: usize, + depth: usize, + width: usize, +} + impl Visitor for TreeFmtVisitor { type Node = AexprNode; @@ -85,16 +310,20 @@ impl Visitor for TreeFmtVisitor { let ae = node.to_aexpr(); let repr = format!("{:E}", ae); - if self.levels.len() <= self.depth as usize { + if self.levels.len() <= self.depth { self.levels.push(vec![]) } // the post-visit ensures the width of this node is known - let row = self.levels.get_mut(self.depth as usize).unwrap(); + let row = self.levels.get_mut(self.depth).unwrap(); // set default values to ensure we format at the right width - row.resize(self.width as usize + 1, "".to_string()); - row[self.width as usize] = repr; + row.resize(self.width + 1, "".to_string()); + row[self.width] = repr; + + // before entering a depth-first branch we preserve the depth to control the width increase + // in the post-visit + self.prev_depth = self.depth; // we will enter depth first, we enter child so depth increases self.depth += 1; @@ -103,217 +332,458 @@ impl Visitor for TreeFmtVisitor { } fn post_visit(&mut self, _node: &Self::Node) -> PolarsResult { - // because we traverse depth first - // every post-visit increases the width as we finished a depth-first branch - self.width += 1; - // we finished this branch so we decrease in depth, back the caller node self.depth -= 1; + + // because we traverse depth first + // the width is increased once after one or more depth-first branches + // this way we avoid empty columns in the resulting tree representation + self.width += if self.prev_depth == self.depth { 1 } else { 0 }; + Ok(VisitRecursion::Continue) } } -fn format_levels(f: &mut Formatter<'_>, levels: &[Vec]) -> std::fmt::Result { - let n_cols = levels.iter().map(|v| v.len()).max().unwrap(); +/// Calculates the number of digits in a `usize` number +/// Useful for the alignment of `usize` values when they are displayed +fn digits(n: usize) -> usize { + if n == 0 { + 1 + } else { + f64::log10(n as f64) as usize + 1 + } +} - let mut col_widths = vec![0usize; n_cols]; +/// Meta-info of a column in a populated `TreeFmtVisitor` required for the pretty-print of a tree +#[derive(Clone, Default, Debug)] +struct TreeViewColumn { + offset: usize, + width: usize, + center: usize, +} - let row_idx_width = levels.len().to_string().len() + 1; - let col_idx_width = n_cols.to_string().len(); - let space = " "; - let dash = "─"; +/// Meta-info of a column in a populated `TreeFmtVisitor` required for the pretty-print of a tree +#[derive(Clone, Default, Debug)] +struct TreeViewRow { + offset: usize, + height: usize, + center: usize, +} - for (i, col_width) in col_widths.iter_mut().enumerate() { - *col_width = levels - .iter() - .map(|row| row.get(i).map(|s| s.as_str()).unwrap_or("").chars().count()) - .max() - .map(|n| if n < col_idx_width { col_idx_width } else { n }) - .unwrap(); - } +/// Meta-info of a cell in a populated `TreeFmtVisitor` +#[derive(Clone, Default, Debug)] +struct TreeViewCell<'a> { + text: Vec<&'a str>, + /// A `Vec` of indices of `TreeViewColumn`-s stored elsewhere in another `Vec` + /// For a cell on a row `i` these indices point to the columns that contain child-cells on a + /// row `i + 1` (if the latter exists) + /// NOTE: might warrant a rethink should this code become used broader + children_columns: Vec, +} - const COL_SPACING: usize = 2; +/// The complete intermediate representation of a `TreeFmtVisitor` that can be drawn on a `Canvas` +/// down the line +#[derive(Default, Debug)] +struct TreeView<'a> { + n_rows: usize, + n_rows_width: usize, + matrix: Vec>>, + /// NOTE: `TreeViewCell`'s `children_columns` field contains indices pointing at the elements + /// of this `Vec` + columns: Vec, + rows: Vec, +} - for (row_count, row) in levels.iter().enumerate() { - if row_count == 0 { - // write the col numbers - writeln!(f)?; - write!(f, "{space:>row_idx_width$} ")?; - for (col_i, (_, col_width)) in - levels.last().unwrap().iter().zip(&col_widths).enumerate() - { - let mut col_spacing = COL_SPACING; - if col_i > 0 { - col_spacing *= 2; +// NOTE: the code below this line is full of hardcoded integer offsets which may not be a big +// problem as long as it remains the private implementation of the pretty-print +/// The conversion from a reference to `levels` field of a `TreeFmtVisitor` +impl<'a> From<&'a [Vec]> for TreeView<'a> { + #[allow(clippy::needless_range_loop)] + fn from(value: &'a [Vec]) -> Self { + let n_rows = value.len(); + let n_cols = value.iter().map(|row| row.len()).max().unwrap_or(0); + if n_rows == 0 || n_cols == 0 { + return TreeView::default(); + } + // the character-width of the highest index of a row + let n_rows_width = digits(n_rows - 1); + + let mut matrix = vec![vec![TreeViewCell::default(); n_cols]; n_rows]; + for i in 0..n_rows { + for j in 0..n_cols { + if j < value[i].len() && !value[i][j].is_empty() { + matrix[i][j].text = value[i][j].split('\n').collect(); + if i < n_rows - 1 { + if j < value[i + 1].len() && !value[i + 1][j].is_empty() { + matrix[i][j].children_columns.push(j); + } + for k in j + 1..n_cols { + if (k >= value[i].len() || value[i][k].is_empty()) + && k < value[i + 1].len() + { + if !value[i + 1][k].is_empty() { + matrix[i][j].children_columns.push(k); + } + } else { + break; + } + } + } } - let half = (col_spacing + 4) / 2; - let remaining = col_spacing + 4 - half; + } + } - // left_half - write!(f, "{space:^half$}")?; - // col num - write!(f, "{col_i:^col_width$}")?; + let mut y_offset = 3; + let mut rows = vec![TreeViewRow::default(); n_rows]; + for i in 0..n_rows { + let mut height = 0; + for j in 0..n_cols { + height = [matrix[i][j].text.len(), height].into_iter().max().unwrap(); + } + height += 2; + rows[i].offset = y_offset; + rows[i].height = height; + rows[i].center = height / 2; + y_offset += height + 3; + } - write!(f, "{space:^remaining$}")?; + let mut x_offset = n_rows_width + 4; + let mut columns = vec![TreeViewColumn::default(); n_cols]; + // the two nested loops below are those `needless_range_loop`s + // more readable this way to my taste + for j in 0..n_cols { + let mut width = 0; + for i in 0..n_rows { + width = [ + matrix[i][j].text.iter().map(|l| l.len()).max().unwrap_or(0), + width, + ] + .into_iter() + .max() + .unwrap(); } - writeln!(f)?; + width += 6; + columns[j].offset = x_offset; + columns[j].width = width; + columns[j].center = width / 2 + width % 2; + x_offset += width; + } - // write the horizontal line - write!(f, "{space:>row_idx_width$} ┌")?; - for (col_i, (_, col_width)) in - levels.last().unwrap().iter().zip(&col_widths).enumerate() - { - let mut col_spacing = COL_SPACING; - if col_i > 0 { - col_spacing *= 2; - } - write!(f, "{dash:─^width$}", width = col_width + col_spacing + 4)?; + Self { + n_rows, + n_rows_width, + matrix, + columns, + rows, + } + } +} + +/// The basic charset that's used for drawing lines and boxes on a `Canvas` +struct Glyphs { + void: char, + vertical_line: char, + horizontal_line: char, + top_left_corner: char, + top_right_corner: char, + bottom_left_corner: char, + bottom_right_corner: char, + tee_down: char, + tee_up: char, +} + +impl Default for Glyphs { + fn default() -> Self { + Self { + void: ' ', + vertical_line: '│', + horizontal_line: '─', + top_left_corner: '╭', + top_right_corner: '╮', + bottom_left_corner: '╰', + bottom_right_corner: '╯', + tee_down: '┬', + tee_up: '┴', + } + } +} + +/// A `Point` on a `Canvas` +#[derive(Clone, Copy)] +struct Point(usize, usize); + +/// The orientation of a line on a `Canvas` +#[derive(Clone, Copy)] +enum Orientation { + Vertical, + Horizontal, +} + +/// `Canvas` +struct Canvas { + width: usize, + height: usize, + canvas: Vec>, + glyphs: Glyphs, +} + +impl Canvas { + fn new(width: usize, height: usize, glyphs: Glyphs) -> Self { + Self { + width, + height, + canvas: vec![vec![glyphs.void; width]; height], + glyphs, + } + } + + /// Draws a single `symbol` on the `Canvas` + /// NOTE: The `Point`s that lay outside of the `Canvas` are quietly ignored + fn draw_symbol(&mut self, point: Point, symbol: char) { + let Point(x, y) = point; + if x < self.width && y < self.height { + self.canvas[y][x] = symbol; + } + } + + /// Draws a line of `length` from an `origin` along the `orientation` + fn draw_line(&mut self, origin: Point, orientation: Orientation, length: usize) { + let Point(x, y) = origin; + if let Orientation::Vertical = orientation { + let mut down = 0; + while down < length { + self.draw_symbol(Point(x, y + down), self.glyphs.vertical_line); + down += 1; } - write!(f, "\n{space:>row_idx_width$} │\n")?; - } else { - // write connecting lines - write!(f, "{space:>row_idx_width$} │")?; - let mut last_empty = true; - let mut before = ""; - for ((col_i, col_name), col_width) in row.iter().enumerate().zip(&col_widths) { - let mut col_spacing = COL_SPACING; - if col_i > 0 { - col_spacing *= 2; - } + } else if let Orientation::Horizontal = orientation { + let mut right = 0; + while right < length { + self.draw_symbol(Point(x + right, y), self.glyphs.horizontal_line); + right += 1; + } + } + } - let half = (*col_width + col_spacing + 4) / 2; - let remaining = col_width + col_spacing + 4 - half - 1; - if last_empty { - // left_half - write!(f, "{space:^half$}")?; - // bar - if col_name.is_empty() { - write!(f, " ")?; - } else { - write!(f, "│")?; - last_empty = false; - before = "│"; + /// Draws a box of `width` and `height` with an `origin` being the top left corner + fn draw_box(&mut self, origin: Point, width: usize, height: usize) { + let Point(x, y) = origin; + self.draw_symbol(origin, self.glyphs.top_left_corner); + self.draw_symbol(Point(x + width - 1, y), self.glyphs.top_right_corner); + self.draw_symbol(Point(x, y + height - 1), self.glyphs.bottom_left_corner); + self.draw_symbol( + Point(x + width - 1, y + height - 1), + self.glyphs.bottom_right_corner, + ); + self.draw_line(Point(x + 1, y), Orientation::Horizontal, width - 2); + self.draw_line( + Point(x + 1, y + height - 1), + Orientation::Horizontal, + width - 2, + ); + self.draw_line(Point(x, y + 1), Orientation::Vertical, height - 2); + self.draw_line( + Point(x + width - 1, y + 1), + Orientation::Vertical, + height - 2, + ); + } + + /// Draws a box of height `2 + text.len()` containing a left-aligned text + fn draw_label_centered(&mut self, center: Point, text: &[&str]) { + if !text.is_empty() { + let Point(x, y) = center; + let text_width = text.iter().map(|l| l.len()).max().unwrap(); + let half_width = text_width / 2 + text_width % 2; + let half_height = text.len() / 2; + if x >= half_width + 2 && y > half_height { + self.draw_box( + Point(x - half_width - 2, y - half_height - 1), + text_width + 4, + text.len() + 2, + ); + for (i, line) in text.iter().enumerate() { + for (j, c) in line.chars().enumerate() { + self.draw_symbol(Point(x - half_width + j, y - half_height + i), c); } - } else { - // left_half - write!(f, "{dash:─^half$}")?; - // bar - write!(f, "╮")?; - before = "╮" } - if (col_i == row.len() - 1) | col_name.is_empty() { - write!(f, "{space:^remaining$}")?; + } + } + } + + /// Draws branched lines from a `Point` to multiple `Point`s below + /// NOTE: the shape of these connections is very specific for this particular kind of the + /// representation of a tree + fn draw_connections(&mut self, from: Point, to: &[Point], branching_offset: usize) { + let mut start_with_corner = true; + let Point(mut x_from, mut y_from) = from; + for (i, Point(x, y)) in to.iter().enumerate() { + if *x >= x_from && *y >= y_from - 1 { + self.draw_symbol(Point(*x, *y), self.glyphs.tee_up); + if *x == x_from { + // if the first connection goes straight below + self.draw_symbol(Point(x_from, y_from - 1), self.glyphs.tee_down); + self.draw_line(Point(x_from, y_from), Orientation::Vertical, *y - y_from); + x_from += 1; } else { - if before == "│" { - write!(f, " ╰")?; + if start_with_corner { + // if the first or the second connection steers to the right + self.draw_symbol(Point(x_from, y_from - 1), self.glyphs.tee_down); + self.draw_line( + Point(x_from, y_from), + Orientation::Vertical, + branching_offset, + ); + y_from += branching_offset; + self.draw_symbol(Point(x_from, y_from), self.glyphs.bottom_left_corner); + start_with_corner = false; + x_from += 1; + } + let length = *x - x_from; + self.draw_line(Point(x_from, y_from), Orientation::Horizontal, length); + x_from += length; + if i == to.len() - 1 { + self.draw_symbol(Point(x_from, y_from), self.glyphs.top_right_corner); } else { - write!(f, "──")?; + self.draw_symbol(Point(x_from, y_from), self.glyphs.tee_down); } - write!(f, "{dash:─^width$}", width = remaining - 2)?; + self.draw_line( + Point(x_from, y_from + 1), + Orientation::Vertical, + *y - y_from - 1, + ); + x_from += 1; } } - writeln!(f)?; - // write vertical bars x 2 - for _ in 0..2 { - write!(f, "{space:>row_idx_width$} │")?; - for ((col_i, col_name), col_width) in row.iter().enumerate().zip(&col_widths) { - let mut col_spacing = COL_SPACING; - if col_i > 0 { - col_spacing *= 2; - } - - let half = (*col_width + col_spacing + 4) / 2; - let remaining = col_width + col_spacing + 4 - half - 1; - - // left_half - write!(f, "{space:^half$}")?; - // bar - let val = if col_name.is_empty() { ' ' } else { '│' }; - write!(f, "{}", val)?; + } + } +} - write!(f, "{space:^remaining$}")?; - } - writeln!(f)?; +/// The actual drawing happens in the conversion of the intermediate `TreeView` into `Canvas` +impl From> for Canvas { + fn from(value: TreeView<'_>) -> Self { + let width = value.n_rows_width + 3 + value.columns.iter().map(|c| c.width).sum::(); + let height = + 3 + value.rows.iter().map(|r| r.height).sum::() + 3 * (value.n_rows - 1); + let mut canvas = Canvas::new(width, height, Glyphs::default()); + + // Axles + let (x, y) = (value.n_rows_width + 2, 1); + canvas.draw_symbol(Point(x, y), '┌'); + canvas.draw_line(Point(x + 1, y), Orientation::Horizontal, width - x); + canvas.draw_line(Point(x, y + 1), Orientation::Vertical, height - y); + + // Row and column indices + for (i, row) in value.rows.iter().enumerate() { + // the prefix `Vec` of spaces compensates for the row indices that are shorter than the + // highest index, effectively, row indices are right-aligned + for (j, c) in vec![' '; value.n_rows_width - digits(i)] + .into_iter() + .chain(format!("{i}").chars()) + .enumerate() + { + canvas.draw_symbol(Point(j + 1, row.offset + row.center), c); } } - - // write the top of the boxes - write!(f, "{space:>row_idx_width$} │")?; - for (col_i, (col_repr, col_width)) in row.iter().zip(&col_widths).enumerate() { - let mut col_spacing = COL_SPACING; - if col_i > 0 { - col_spacing *= 2; + for (j, col) in value.columns.iter().enumerate() { + let j_width = digits(j); + let start = col.offset + col.center - (j_width / 2 + j_width % 2); + for (k, c) in format!("{j}").chars().enumerate() { + canvas.draw_symbol(Point(start + k, 0), c); } - let char_count = col_repr.chars().count() + 4; - let half = (*col_width + col_spacing + 4 - char_count) / 2; - let remaining = col_width + col_spacing + 4 - half - char_count; - - write!(f, "{space:^half$}")?; + } - if !col_repr.is_empty() { - write!(f, "╭")?; - write!(f, "{dash:─^width$}", width = char_count - 2)?; - write!(f, "╮")?; - } else { - write!(f, " ")?; + // Non-empty cells (nodes) and their connections (edges) + for (i, row) in value.matrix.iter().enumerate() { + for (j, cell) in row.iter().enumerate() { + if !cell.text.is_empty() { + canvas.draw_label_centered( + Point( + value.columns[j].offset + value.columns[j].center, + value.rows[i].offset + value.rows[i].center, + ), + &cell.text, + ); + } } - write!(f, "{space:^remaining$}")?; } - writeln!(f)?; - - // write column names and spacing - write!(f, "{row_count:>row_idx_width$} │")?; - for (col_i, (col_repr, col_width)) in row.iter().zip(&col_widths).enumerate() { - let mut col_spacing = COL_SPACING; - if col_i > 0 { - col_spacing *= 2; - } - let char_count = col_repr.chars().count() + 4; - let half = (*col_width + col_spacing + 4 - char_count) / 2; - let remaining = col_width + col_spacing + 4 - half - char_count; - write!(f, "{space:^half$}")?; - - if !col_repr.is_empty() { - write!(f, "│ {} │", col_repr)?; + fn even_odd(a: usize, b: usize) -> usize { + if a % 2 == 0 && b % 2 == 1 { + 1 } else { - write!(f, " ")?; + 0 } - write!(f, "{space:^remaining$}")?; } - writeln!(f)?; - - // write the bottom of the boxes - write!(f, "{space:>row_idx_width$} │")?; - for (col_i, (col_repr, col_width)) in row.iter().zip(&col_widths).enumerate() { - let mut col_spacing = COL_SPACING; - if col_i > 0 { - col_spacing *= 2; - } - let char_count = col_repr.chars().count() + 4; - let half = (*col_width + col_spacing + 4 - char_count) / 2; - let remaining = col_width + col_spacing + 4 - half - char_count; - write!(f, "{space:^half$}")?; - - if !col_repr.is_empty() { - write!(f, "╰")?; - write!(f, "{dash:─^width$}", width = char_count - 2)?; - write!(f, "╯")?; - } else { - write!(f, " ")?; + for (i, row) in value.matrix.iter().enumerate() { + for (j, cell) in row.iter().enumerate() { + if !cell.text.is_empty() && i < value.rows.len() - 1 { + let children_points = cell + .children_columns + .iter() + .map(|k| { + let child_total_padding = + value.rows[i + 1].height - value.matrix[i + 1][*k].text.len() - 2; + let even_cell_in_odd_row = even_odd( + value.matrix[i + 1][*k].text.len(), + value.rows[i + 1].height, + ); + Point( + value.columns[*k].offset + value.columns[*k].center - 1, + value.rows[i + 1].offset + + child_total_padding / 2 + + child_total_padding % 2 + - even_cell_in_odd_row, + ) + }) + .collect::>(); + + let parent_total_padding = + value.rows[i].height - value.matrix[i][j].text.len() - 2; + let even_cell_in_odd_row = + even_odd(value.matrix[i][j].text.len(), value.rows[i].height); + + canvas.draw_connections( + Point( + value.columns[j].offset + value.columns[j].center - 1, + value.rows[i].offset + value.rows[i].height + - parent_total_padding / 2 + - even_cell_in_odd_row, + ), + &children_points, + parent_total_padding / 2 + 1 + even_cell_in_odd_row, + ); + } } - write!(f, "{space:^remaining$}")?; } - writeln!(f)?; + + canvas } +} - Ok(()) +impl Display for Canvas { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for row in &self.canvas { + writeln!(f, "{}", row.iter().collect::().trim_end())?; + } + + Ok(()) + } } impl Display for TreeFmtVisitor { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - format_levels(f, &self.levels) + Debug::fmt(self, f) + } +} + +impl Debug for TreeFmtVisitor { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let tree_view: TreeView<'_> = self.levels.as_slice().into(); + let canvas: Canvas = tree_view.into(); + write!(f, "{canvas}")?; + + Ok(()) } } @@ -328,7 +798,7 @@ mod test { let mut arena = Default::default(); let node = to_aexpr(e, &mut arena); - let mut visitor = TreeFmtVisitor::new(); + let mut visitor = TreeFmtVisitor::default(); AexprNode::with_context(node, &mut arena, |ae_node| ae_node.visit(&mut visitor)).unwrap(); let expected: &[&[&str]] = &[ @@ -341,4 +811,97 @@ mod test { assert_eq!(visitor.levels, expected); } + + #[test] + fn test_tree_format_levels() { + let e = (col("a") + col("b")).pow(2) + col("c") * col("d"); + let mut arena = Default::default(); + let node = to_aexpr(e, &mut arena); + + let mut visitor = TreeFmtVisitor::default(); + + AexprNode::with_context(node, &mut arena, |ae_node| ae_node.visit(&mut visitor)).unwrap(); + + let expected_lines = vec![ + " 0 1 2 3 4", + " ┌─────────────────────────────────────────────────────────────────────────", + " │", + " │ ╭───────────╮", + " 0 │ │ binary: + │", + " │ ╰─────┬┬────╯", + " │ ││", + " │ │╰───────────────────────────╮", + " │ │ │", + " │ ╭─────┴─────╮ ╭───────┴───────╮", + " 1 │ │ binary: * │ │ function: pow │", + " │ ╰─────┬┬────╯ ╰───────┬┬──────╯", + " │ ││ ││", + " │ │╰───────────╮ │╰───────────────╮", + " │ │ │ │ │", + " │ ╭───┴────╮ ╭───┴────╮ ╭───┴────╮ ╭─────┴─────╮", + " 2 │ │ col(d) │ │ col(c) │ │ lit(2) │ │ binary: + │", + " │ ╰────────╯ ╰────────╯ ╰────────╯ ╰─────┬┬────╯", + " │ ││", + " │ │╰───────────╮", + " │ │ │", + " │ ╭───┴────╮ ╭───┴────╮", + " 3 │ │ col(b) │ │ col(a) │", + " │ ╰────────╯ ╰────────╯", + ]; + for (i, (line, expected_line)) in + format!("{visitor}").lines().zip(expected_lines).enumerate() + { + assert_eq!(line, expected_line, "Difference at line {}", i + 1); + } + } + + #[cfg(feature = "range")] + #[test] + fn test_tree_format_levels_with_range() { + let e = (col("a") + col("b")).pow(2) + + int_range( + Expr::Literal(LiteralValue::Int64(0)), + Expr::Literal(LiteralValue::Int64(3)), + 1, + polars_core::datatypes::DataType::Int64, + ); + let mut arena = Default::default(); + let node = to_aexpr(e, &mut arena); + + let mut visitor = TreeFmtVisitor::default(); + + AexprNode::with_context(node, &mut arena, |ae_node| ae_node.visit(&mut visitor)).unwrap(); + + let expected_lines = vec![ + " 0 1 2 3 4", + " ┌───────────────────────────────────────────────────────────────────────────────────", + " │", + " │ ╭───────────╮", + " 0 │ │ binary: + │", + " │ ╰─────┬┬────╯", + " │ ││", + " │ │╰────────────────────────────────╮", + " │ │ │", + " │ ╭──────────┴──────────╮ ╭───────┴───────╮", + " 1 │ │ function: int_range │ │ function: pow │", + " │ ╰──────────┬┬─────────╯ ╰───────┬┬──────╯", + " │ ││ ││", + " │ │╰────────────────╮ │╰───────────────╮", + " │ │ │ │ │", + " │ ╭───┴────╮ ╭───┴────╮ ╭───┴────╮ ╭─────┴─────╮", + " 2 │ │ lit(3) │ │ lit(0) │ │ lit(2) │ │ binary: + │", + " │ ╰────────╯ ╰────────╯ ╰────────╯ ╰─────┬┬────╯", + " │ ││", + " │ │╰───────────╮", + " │ │ │", + " │ ╭───┴────╮ ╭───┴────╮", + " 3 │ │ col(b) │ │ col(a) │", + " │ ╰────────╯ ╰────────╯", + ]; + for (i, (line, expected_line)) in + format!("{visitor}").lines().zip(expected_lines).enumerate() + { + assert_eq!(line, expected_line, "Difference at line {}", i + 1); + } + } } diff --git a/crates/polars-plan/src/logical_plan/visitor/expr.rs b/crates/polars-plan/src/logical_plan/visitor/expr.rs index 36a7f3032546..74f48f60592e 100644 --- a/crates/polars-plan/src/logical_plan/visitor/expr.rs +++ b/crates/polars-plan/src/logical_plan/visitor/expr.rs @@ -1,4 +1,5 @@ use polars_core::prelude::{Field, Schema}; +use polars_utils::unitvec; use super::*; use crate::prelude::*; @@ -8,11 +9,11 @@ impl TreeWalker for Expr { &'a self, op: &mut dyn FnMut(&Self) -> PolarsResult, ) -> PolarsResult { - let mut scratch = vec![]; + let mut scratch = unitvec![]; self.nodes(&mut scratch); - for child in scratch { + for &child in scratch.as_slice() { match op(child)? { // let the recursion continue VisitRecursion::Continue | VisitRecursion::Skip => {}, @@ -54,7 +55,7 @@ impl AexprNode { where F: FnOnce(AexprNode) -> T, { - // safety: we drop this context before arena is out of scope + // SAFETY: we drop this context before arena is out of scope unsafe { op(Self::new(node, arena)) } } @@ -63,7 +64,7 @@ impl AexprNode { where F: FnOnce(AexprNode, &mut Arena) -> T, { - // safety: we drop this context before arena is out of scope + // SAFETY: we drop this context before arena is out of scope unsafe { op(Self::new(node, arena), arena) } } @@ -185,7 +186,7 @@ impl AexprNode { loop { match (scratch1.pop(), scratch2.pop()) { (Some(l), Some(r)) => { - // safety: we can pass a *mut pointer + // SAFETY: we can pass a *mut pointer // the equality operation will not access mutable let l = unsafe { AexprNode::from_raw(l, self.arena) }; let r = unsafe { AexprNode::from_raw(r, self.arena) }; diff --git a/crates/polars-plan/src/logical_plan/visitor/lp.rs b/crates/polars-plan/src/logical_plan/visitor/lp.rs index b13457ba1acb..b8f2197d169d 100644 --- a/crates/polars-plan/src/logical_plan/visitor/lp.rs +++ b/crates/polars-plan/src/logical_plan/visitor/lp.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use polars_core::schema::SchemaRef; +use polars_utils::unitvec; use super::*; use crate::prelude::*; @@ -30,7 +31,7 @@ impl ALogicalPlanNode { where F: FnMut(ALogicalPlanNode) -> T, { - // safety: we drop this context before arena is out of scope + // SAFETY: we drop this context before arena is out of scope unsafe { op(Self::new(node, arena)) } } @@ -98,10 +99,10 @@ impl TreeWalker for ALogicalPlanNode { &'a self, op: &mut dyn FnMut(&Self) -> PolarsResult, ) -> PolarsResult { - let mut scratch = vec![]; + let mut scratch = unitvec![]; self.to_alp().copy_inputs(&mut scratch); - for node in scratch { + for &node in scratch.as_slice() { let lp_node = ALogicalPlanNode { node, arena: self.arena, diff --git a/crates/polars-plan/src/prelude.rs b/crates/polars-plan/src/prelude.rs index 85da66d68b61..0b3f37cdfb22 100644 --- a/crates/polars-plan/src/prelude.rs +++ b/crates/polars-plan/src/prelude.rs @@ -9,11 +9,6 @@ pub(crate) use polars_time::in_nanoseconds_window; feature = "dtype-time" ))] pub(crate) use polars_time::prelude::*; -#[cfg(feature = "rolling_window")] -pub(crate) use polars_time::{ - chunkedarray::{RollingOptions, RollingOptionsImpl}, - Duration, -}; pub use polars_utils::arena::{Arena, Node}; pub use crate::dsl::*; diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index 69787ff15467..6c8a7b67421a 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -1,12 +1,10 @@ use std::fmt::Formatter; use std::iter::FlatMap; -use std::sync::Arc; use polars_core::prelude::*; +use polars_utils::idx_vec::UnitVec; use smartstring::alias::String as SmartString; -use crate::logical_plan::iterator::ArenaExprIter; -use crate::logical_plan::Context; use crate::prelude::consts::{LEN, LITERAL_NAME}; use crate::prelude::*; @@ -42,30 +40,27 @@ pub(crate) fn fmt_column_delimited>( pub trait PushNode { fn push_node(&mut self, value: Node); + + fn extend_from_slice(&mut self, values: &[Node]); } impl PushNode for Vec { fn push_node(&mut self, value: Node) { self.push(value) } -} -impl PushNode for [Option; 2] { - fn push_node(&mut self, value: Node) { - match self { - [None, None] => self[0] = Some(value), - [Some(_), None] => self[1] = Some(value), - _ => panic!("cannot push more than 2 nodes"), - } + fn extend_from_slice(&mut self, values: &[Node]) { + Vec::extend_from_slice(self, values) } } -impl PushNode for [Option; 1] { +impl PushNode for UnitVec { fn push_node(&mut self, value: Node) { - match self { - [None] => self[0] = Some(value), - _ => panic!("cannot push more than 1 node"), - } + self.push(value) + } + + fn extend_from_slice(&mut self, values: &[Node]) { + UnitVec::extend(self, values.iter().copied()) } } @@ -76,16 +71,6 @@ pub(crate) fn is_scan(plan: &ALogicalPlan) -> bool { ) } -impl PushNode for &mut [Option] { - fn push_node(&mut self, value: Node) { - if self[0].is_some() { - self[1] = Some(value) - } else { - self[0] = Some(value) - } - } -} - /// A projection that only takes a column or a column + alias. #[cfg(feature = "meta")] pub(crate) fn aexpr_is_simple_projection(current_node: Node, arena: &Arena) -> bool { @@ -94,22 +79,17 @@ pub(crate) fn aexpr_is_simple_projection(current_node: Node, arena: &Arena) -> bool { - arena.iter(current_node).all(|(_node, e)| { - use AExpr::*; - match e { - AnonymousFunction { options, .. } | Function { options, .. } => { - !matches!(options.collect_groups, ApplyOptions::GroupWise) - }, - Column(_) - | Alias(_, _) - | Literal(_) - | BinaryExpr { .. } - | Ternary { .. } - | Cast { .. } => true, - _ => false, - } - }) +pub(crate) fn single_aexpr_is_elementwise(ae: &AExpr) -> bool { + use AExpr::*; + match ae { + AnonymousFunction { options, .. } | Function { options, .. } => { + !matches!(options.collect_groups, ApplyOptions::GroupWise) + }, + Column(_) | Alias(_, _) | Literal(_) | BinaryExpr { .. } | Ternary { .. } | Cast { .. } => { + true + }, + _ => false, + } } pub fn has_aexpr(current_node: Node, arena: &Arena, matches: F) -> bool @@ -128,7 +108,7 @@ pub fn has_aexpr_literal(current_node: Node, arena: &Arena) -> bool { } /// Can check if an expression tree has a matching_expr. This -/// requires a dummy expression to be created that will be used to patter match against. +/// requires a dummy expression to be created that will be used to pattern match against. pub(crate) fn has_expr(current_expr: &Expr, matches: F) -> bool where F: Fn(&Expr) -> bool, diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index f4e30e738b30..8ed9a95cb5e8 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -69,20 +69,20 @@ pub fn convert_columns_amortized<'a, I: IntoIterator>( let values_size = allocate_rows_buf(&flattened_columns, &mut rows.values, &mut rows.offsets); for (arr, field) in flattened_columns.iter().zip(flattened_fields.iter()) { - // Safety: + // SAFETY: // we allocated rows with enough bytes. unsafe { encode_array(&**arr, field, rows) } } - // safety: values are initialized + // SAFETY: values are initialized unsafe { rows.values.set_len(values_size) } } else { let values_size = allocate_rows_buf(columns, &mut rows.values, &mut rows.offsets); for (arr, field) in columns.iter().zip(fields) { - // Safety: + // SAFETY: // we allocated rows with enough bytes. unsafe { encode_array(&**arr, field, rows) } } - // safety: values are initialized + // SAFETY: values are initialized unsafe { rows.values.set_len(values_size) } } } @@ -170,7 +170,9 @@ pub fn allocate_rows_buf( let has_variable = columns.iter().any(|arr| { matches!( arr.data_type(), - ArrowDataType::BinaryView | ArrowDataType::Dictionary(_, _, _) + ArrowDataType::BinaryView + | ArrowDataType::Dictionary(_, _, _) + | ArrowDataType::LargeBinary ) }); @@ -183,7 +185,9 @@ pub fn allocate_rows_buf( .map(|arr| { if matches!( arr.data_type(), - ArrowDataType::BinaryView | ArrowDataType::Dictionary(_, _, _) + ArrowDataType::BinaryView + | ArrowDataType::Dictionary(_, _, _) + | ArrowDataType::LargeBinary ) { 0 } else { @@ -219,6 +223,23 @@ pub fn allocate_rows_buf( } processed_count += 1; }, + ArrowDataType::LargeBinary => { + let array = array.as_any().downcast_ref::>().unwrap(); + if processed_count == 0 { + for opt_val in array.into_iter() { + unsafe { + lengths.push_unchecked( + row_size_fixed + crate::variable::encoded_len(opt_val), + ); + } + } + } else { + for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) { + *row_length += crate::variable::encoded_len(opt_val) + } + } + processed_count += 1; + }, ArrowDataType::Dictionary(_, _, _) => { let array = array .as_any() diff --git a/crates/polars-row/src/fixed.rs b/crates/polars-row/src/fixed.rs index 3b0560fc0a10..5916c4e3e68a 100644 --- a/crates/polars-row/src/fixed.rs +++ b/crates/polars-row/src/fixed.rs @@ -12,7 +12,6 @@ use crate::row::{RowsEncoded, SortField}; pub(crate) trait FromSlice { fn from_slice(slice: &[u8]) -> Self; - fn from_slice_inverted(slice: &[u8]) -> Self; } impl FromSlice for [u8; N] { @@ -20,10 +19,6 @@ impl FromSlice for [u8; N] { fn from_slice(slice: &[u8]) -> Self { slice.try_into().unwrap() } - - fn from_slice_inverted(_slice: &[u8]) -> Self { - todo!() - } } /// Encodes a value of a particular fixed width type into bytes diff --git a/crates/polars-row/src/row.rs b/crates/polars-row/src/row.rs index 9e17d86d400c..6752d065228f 100644 --- a/crates/polars-row/src/row.rs +++ b/crates/polars-row/src/row.rs @@ -33,10 +33,10 @@ fn checks(offsets: &[usize]) { unsafe fn rows_to_array(buf: Vec, offsets: Vec) -> BinaryArray { checks(&offsets); - // Safety: we checked overflow + // SAFETY: we checked overflow let offsets = std::mem::transmute::, Vec>(offsets); - // Safety: monotonically increasing + // SAFETY: monotonically increasing let offsets = Offsets::new_unchecked(offsets); BinaryArray::new(ArrowDataType::LargeBinary, offsets.into(), buf.into(), None) diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 162eec9b5277..1f2d32413563 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -31,7 +31,7 @@ default = [] nightly = [] csv = ["polars-lazy/csv"] ipc = ["polars-lazy/ipc"] -json = ["polars-lazy/json"] +json = ["polars-lazy/json", "polars-plan/extract_jsonpath"] binary_encoding = ["polars-lazy/binary_encoding"] diagonal_concat = ["polars-lazy/diagonal_concat"] dtype-decimal = ["polars-lazy/dtype-decimal"] diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 8e67535e6942..4e84219c9963 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -5,7 +5,6 @@ use polars_core::prelude::*; use polars_error::to_compute_err; use polars_lazy::prelude::*; use polars_plan::prelude::*; -use polars_plan::utils::expressions_to_schema; use sqlparser::ast::{ Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 0142daf7433e..4d5c74982b1d 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -3,10 +3,8 @@ use std::ops::Div; use polars_core::export::regex; use polars_core::prelude::*; use polars_error::to_compute_err; -use polars_lazy::dsl::Expr; use polars_lazy::prelude::*; use polars_plan::prelude::LiteralValue::Null; -use polars_plan::prelude::{col, lit, when}; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; #[cfg(feature = "dtype-decimal")] @@ -546,7 +544,7 @@ impl SQLExprVisitor<'_> { } /// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr. - fn visit_anyvalue( + fn visit_any_value( &self, value: &SQLValue, op: Option<&UnaryOperator>, @@ -669,12 +667,12 @@ impl SQLExprVisitor<'_> { .iter() .map(|e| { if let SQLExpr::Value(v) = e { - let av = self.visit_anyvalue(v, None)?; + let av = self.visit_any_value(v, None)?; Ok(av) } else if let SQLExpr::UnaryOp {op, expr} = e { match expr.as_ref() { SQLExpr::Value(v) => { - let av = self.visit_anyvalue(v, Some(op))?; + let av = self.visit_any_value(v, Some(op))?; Ok(av) }, _ => Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e)) @@ -873,6 +871,8 @@ fn process_join_on( polars_bail!(InvalidOperation: "SQL join clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op); }, } + } else if let SQLExpr::Nested(expr) = expression { + process_join_on(expr, left_name, right_name) } else { polars_bail!(InvalidOperation: "SQL join clauses support '=' constraints combined with 'AND'; found expression = {:?}", expression); } diff --git a/crates/polars-sql/tests/statements.rs b/crates/polars-sql/tests/statements.rs index 5fbbcb1b8336..712df6873d90 100644 --- a/crates/polars-sql/tests/statements.rs +++ b/crates/polars-sql/tests/statements.rs @@ -353,6 +353,52 @@ fn test_compound_join_nested_and() { ); } +#[test] +fn test_compound_join_nested_and_with_brackets() { + let df1 = df! { + "a" => [1, 2, 3, 4, 5], + "b" => [1, 2, 3, 4, 5], + "c" => [0, 3, 4, 5, 6], + "d" => [0, 3, 4, 5, 6], + } + .unwrap(); + let df2 = df! { + "a" => [1, 2, 3, 4, 5], + "b" => [1, 3, 3, 5, 6], + "c" => [0, 3, 4, 5, 6], + "d" => [0, 3, 4, 5, 6] + } + .unwrap(); + let mut ctx = SQLContext::new(); + ctx.register("df1", df1.lazy()); + ctx.register("df2", df2.lazy()); + + let sql = r#" + SELECT * FROM df1 + INNER JOIN df2 ON + df1.a = df2.a AND + ((df1.b = df2.b AND + df1.c = df2.c) AND + df1.d = df2.d) + "#; + let actual = ctx.execute(sql).unwrap().collect().unwrap(); + + let expected = df! { + "a" => [1, 3], + "b" => [1, 3], + "c" => [0, 4], + "d" => [0, 4], + } + .unwrap(); + + assert!( + actual.equals(&expected), + "expected = {:?}\nactual={:?}", + expected, + actual + ); +} + #[test] #[should_panic] fn test_compound_invalid_1() { diff --git a/crates/polars-sql/tests/udf.rs b/crates/polars-sql/tests/udf.rs index fd52b63eea33..18b55629fb07 100644 --- a/crates/polars-sql/tests/udf.rs +++ b/crates/polars-sql/tests/udf.rs @@ -1,8 +1,4 @@ -use std::sync::Arc; - -use polars_core::prelude::{DataType, Field, *}; -use polars_core::series::Series; -use polars_error::PolarsResult; +use polars_core::prelude::*; use polars_lazy::prelude::IntoLazy; use polars_plan::prelude::{GetOutput, UserDefinedFunction}; use polars_sql::function_registry::FunctionRegistry; diff --git a/crates/polars-time/src/chunkedarray/datetime.rs b/crates/polars-time/src/chunkedarray/datetime.rs index 9ce60778eba2..f0111be9047f 100644 --- a/crates/polars-time/src/chunkedarray/datetime.rs +++ b/crates/polars-time/src/chunkedarray/datetime.rs @@ -1,4 +1,3 @@ -use arrow; use arrow::array::{Array, PrimitiveArray}; use arrow::compute::cast::{cast, CastOptions}; use arrow::compute::temporal; @@ -161,8 +160,6 @@ impl DatetimeMethods for DatetimeChunked {} #[cfg(test)] mod test { - use chrono::NaiveDateTime; - use super::*; #[test] diff --git a/crates/polars-time/src/chunkedarray/duration.rs b/crates/polars-time/src/chunkedarray/duration.rs index f64acfbb2afd..d4a590f9ee44 100644 --- a/crates/polars-time/src/chunkedarray/duration.rs +++ b/crates/polars-time/src/chunkedarray/duration.rs @@ -34,36 +34,46 @@ impl DurationMethods for DurationChunked { /// Extract the hours from a `Duration` fn hours(&self) -> Int64Chunked { match self.time_unit() { - TimeUnit::Milliseconds => &self.0 / (MILLISECONDS * SECONDS_IN_HOUR), - TimeUnit::Microseconds => &self.0 / (MICROSECONDS * SECONDS_IN_HOUR), - TimeUnit::Nanoseconds => &self.0 / (NANOSECONDS * SECONDS_IN_HOUR), + TimeUnit::Milliseconds => { + (&self.0).wrapping_trunc_div_scalar(MILLISECONDS * SECONDS_IN_HOUR) + }, + TimeUnit::Microseconds => { + (&self.0).wrapping_trunc_div_scalar(MICROSECONDS * SECONDS_IN_HOUR) + }, + TimeUnit::Nanoseconds => { + (&self.0).wrapping_trunc_div_scalar(NANOSECONDS * SECONDS_IN_HOUR) + }, } } /// Extract the days from a `Duration` fn days(&self) -> Int64Chunked { match self.time_unit() { - TimeUnit::Milliseconds => &self.0 / MILLISECONDS_IN_DAY, - TimeUnit::Microseconds => &self.0 / (MICROSECONDS * SECONDS_IN_DAY), - TimeUnit::Nanoseconds => &self.0 / (NANOSECONDS * SECONDS_IN_DAY), + TimeUnit::Milliseconds => (&self.0).wrapping_trunc_div_scalar(MILLISECONDS_IN_DAY), + TimeUnit::Microseconds => { + (&self.0).wrapping_trunc_div_scalar(MICROSECONDS * SECONDS_IN_DAY) + }, + TimeUnit::Nanoseconds => { + (&self.0).wrapping_trunc_div_scalar(NANOSECONDS * SECONDS_IN_DAY) + }, } } /// Extract the seconds from a `Duration` fn minutes(&self) -> Int64Chunked { match self.time_unit() { - TimeUnit::Milliseconds => &self.0 / (MILLISECONDS * 60), - TimeUnit::Microseconds => &self.0 / (MICROSECONDS * 60), - TimeUnit::Nanoseconds => &self.0 / (NANOSECONDS * 60), + TimeUnit::Milliseconds => (&self.0).wrapping_trunc_div_scalar(MILLISECONDS * 60), + TimeUnit::Microseconds => (&self.0).wrapping_trunc_div_scalar(MICROSECONDS * 60), + TimeUnit::Nanoseconds => (&self.0).wrapping_trunc_div_scalar(NANOSECONDS * 60), } } /// Extract the seconds from a `Duration` fn seconds(&self) -> Int64Chunked { match self.time_unit() { - TimeUnit::Milliseconds => &self.0 / MILLISECONDS, - TimeUnit::Microseconds => &self.0 / MICROSECONDS, - TimeUnit::Nanoseconds => &self.0 / NANOSECONDS, + TimeUnit::Milliseconds => (&self.0).wrapping_trunc_div_scalar(MILLISECONDS), + TimeUnit::Microseconds => (&self.0).wrapping_trunc_div_scalar(MICROSECONDS), + TimeUnit::Nanoseconds => (&self.0).wrapping_trunc_div_scalar(NANOSECONDS), } } @@ -71,8 +81,10 @@ impl DurationMethods for DurationChunked { fn milliseconds(&self) -> Int64Chunked { match self.time_unit() { TimeUnit::Milliseconds => self.0.clone(), - TimeUnit::Microseconds => self.0.clone() / 1000, - TimeUnit::Nanoseconds => &self.0 / NANOSECONDS_IN_MILLISECOND, + TimeUnit::Microseconds => self.0.clone().wrapping_trunc_div_scalar(1000), + TimeUnit::Nanoseconds => { + (&self.0).wrapping_trunc_div_scalar(NANOSECONDS_IN_MILLISECOND) + }, } } @@ -81,7 +93,7 @@ impl DurationMethods for DurationChunked { match self.time_unit() { TimeUnit::Milliseconds => &self.0 * 1000, TimeUnit::Microseconds => self.0.clone(), - TimeUnit::Nanoseconds => &self.0 / 1000, + TimeUnit::Nanoseconds => (&self.0).wrapping_trunc_div_scalar(1000), } } diff --git a/crates/polars-time/src/chunkedarray/kernels.rs b/crates/polars-time/src/chunkedarray/kernels.rs index 5e453925f8f6..8526c180ad1a 100644 --- a/crates/polars-time/src/chunkedarray/kernels.rs +++ b/crates/polars-time/src/chunkedarray/kernels.rs @@ -1,6 +1,6 @@ //! macros that define kernels for extracting //! `week`, `weekday`, `year`, `hour` etc. from primitive arrays. -use arrow::array::{ArrayRef, BooleanArray, PrimitiveArray}; +use arrow::array::{BooleanArray, PrimitiveArray}; use arrow::compute::arity::unary; #[cfg(feature = "dtype-time")] use arrow::temporal_conversions::time64ns_to_time_opt; @@ -8,7 +8,7 @@ use arrow::temporal_conversions::{ date32_to_datetime_opt, timestamp_ms_to_datetime_opt, timestamp_ns_to_datetime_opt, timestamp_us_to_datetime_opt, }; -use chrono::{Datelike, NaiveDate, NaiveDateTime, Timelike}; +use chrono::{Datelike, Timelike}; use super::super::windows::calendar::*; use super::*; diff --git a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs index 205233b15885..5eb03d2bcc0f 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs @@ -1,8 +1,6 @@ mod dispatch; mod rolling_kernels; -use std::convert::TryFrom; - use arrow::array::{Array, ArrayRef, PrimitiveArray}; use arrow::legacy::kernels::rolling; pub use dispatch::*; diff --git a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs index f77ef5f284cb..8c02b38625a7 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs @@ -39,7 +39,7 @@ where if len < (min_periods as IdxSize) { None } else { - // safety: + // SAFETY: // we are in bounds Some(unsafe { agg_window.update(start as usize, end as usize) }) } diff --git a/crates/polars-time/src/chunkedarray/string/strptime.rs b/crates/polars-time/src/chunkedarray/string/strptime.rs index abe6c2b3df8c..c161f3bb27de 100644 --- a/crates/polars-time/src/chunkedarray/string/strptime.rs +++ b/crates/polars-time/src/chunkedarray/string/strptime.rs @@ -120,7 +120,7 @@ impl StrpTimeState { debug_assert!(offset < val.len()); let b = *val.get_unchecked(offset); if *fmt_b == ESCAPE { - // Safety: we must ensure we provide valid patterns + // SAFETY: we must ensure we provide valid patterns let next = fmt_iter.next(); debug_assert!(next.is_some()); match next.unwrap_unchecked() { @@ -209,7 +209,7 @@ pub(super) fn fmt_len(fmt: &[u8]) -> Option { while let Some(&val) = iter.next() { match val { - b'%' => match iter.next().expect("invalid patter") { + b'%' => match iter.next().expect("invalid pattern") { b'Y' => cnt += 4, b'y' => cnt += 2, b'd' => cnt += 2, diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index e9d72b01d249..2b1df2fd19d7 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -1,7 +1,6 @@ use arrow::legacy::time_zone::Tz; use arrow::legacy::utils::CustomIterTools; use polars_core::export::rayon::prelude::*; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::utils::ensure_sorted_arg; @@ -151,18 +150,18 @@ impl Wrap<&DataFrame> { TimeUnit::Milliseconds, None, ), - Int32 => { - let time_type = Datetime(TimeUnit::Nanoseconds, None); - let dt = time.cast(&Int64).unwrap().cast(&time_type).unwrap(); + UInt32 | UInt64 | Int32 => { + let time_type_dt = Datetime(TimeUnit::Nanoseconds, None); + let dt = time.cast(&Int64).unwrap().cast(&time_type_dt).unwrap(); let (out, by, gt) = self.impl_group_by_rolling( dt, by, options, TimeUnit::Nanoseconds, None, - &time_type, + &time_type_dt, )?; - let out = out.cast(&Int64).unwrap().cast(&Int32).unwrap(); + let out = out.cast(&Int64).unwrap().cast(time_type).unwrap(); return Ok((out, by, gt)); }, Int64 => { @@ -661,7 +660,7 @@ fn update_subgroups_idx( mod test { use chrono::prelude::*; use polars_ops::prelude::*; - use polars_utils::idxvec; + use polars_utils::unitvec; use super::*; @@ -899,12 +898,12 @@ mod test { let expected = GroupsProxy::Idx( vec![ - (0 as IdxSize, idxvec![0 as IdxSize, 1, 2]), - (2, idxvec![2]), - (5, idxvec![5, 6]), - (6, idxvec![6]), - (3, idxvec![3, 4]), - (4, idxvec![4]), + (0 as IdxSize, unitvec![0 as IdxSize, 1, 2]), + (2, unitvec![2]), + (5, unitvec![5, 6]), + (6, unitvec![6]), + (3, unitvec![3, 4]), + (4, unitvec![4]), ] .into(), ); diff --git a/crates/polars-time/src/series/mod.rs b/crates/polars-time/src/series/mod.rs index 702297362c69..c3fbf9aac635 100644 --- a/crates/polars-time/src/series/mod.rs +++ b/crates/polars-time/src/series/mod.rs @@ -255,7 +255,7 @@ pub trait TemporalMethods: AsSeries { /// Convert date(time) object to timestamp in [`TimeUnit`]. fn timestamp(&self, tu: TimeUnit) -> PolarsResult { let s = self.as_series(); - if matches!(s.dtype(), DataType::Time) { + if matches!(s.dtype(), DataType::Time | DataType::Duration(_)) { polars_bail!(opq = timestamp, s.dtype()); } else { s.cast(&DataType::Datetime(tu, None)) diff --git a/crates/polars-time/src/upsample.rs b/crates/polars-time/src/upsample.rs index a7af41238cc5..ff010dfa2429 100644 --- a/crates/polars-time/src/upsample.rs +++ b/crates/polars-time/src/upsample.rs @@ -46,7 +46,10 @@ pub trait PolarsUpsample { offset: Duration, ) -> PolarsResult; - /// Upsample a DataFrame at a regular frequency. + /// Upsample a [`DataFrame`] at a regular frequency. + /// + /// Similar to [`upsample`][PolarsUpsample::upsample], but order of the + /// DataFrame is maintained when `by` is specified. /// /// # Arguments /// * `by` - First group by these columns and then upsample for every group @@ -160,8 +163,8 @@ fn upsample_single_impl( Datetime(tu, tz) => { let s = index_column.cast(&Int64).unwrap(); let ca = s.i64().unwrap(); - let first = ca.into_iter().flatten().next(); - let last = ca.into_iter().flatten().next_back(); + let first = ca.iter().flatten().next(); + let last = ca.iter().flatten().next_back(); match (first, last) { (Some(first), Some(last)) => { let tz = match tz { diff --git a/crates/polars-time/src/windows/test.rs b/crates/polars-time/src/windows/test.rs index 6fc9c663efe2..d84eb921e320 100644 --- a/crates/polars-time/src/windows/test.rs +++ b/crates/polars-time/src/windows/test.rs @@ -2,7 +2,6 @@ use arrow::temporal_conversions::timestamp_ns_to_datetime; use chrono::prelude::*; use polars_core::prelude::*; -use crate::date_range::datetime_range_i64; use crate::prelude::*; #[test] diff --git a/crates/polars-time/src/windows/window.rs b/crates/polars-time/src/windows/window.rs index 3666013a18e3..8adb7520ecfe 100644 --- a/crates/polars-time/src/windows/window.rs +++ b/crates/polars-time/src/windows/window.rs @@ -5,7 +5,6 @@ use chrono::NaiveDateTime; use chrono::TimeZone; use now::DateTimeNow; use polars_core::prelude::*; -use polars_core::utils::arrow::temporal_conversions::timeunit_scale; use crate::prelude::*; diff --git a/crates/polars-utils/src/floor_divmod.rs b/crates/polars-utils/src/floor_divmod.rs new file mode 100644 index 000000000000..14c02fc8d257 --- /dev/null +++ b/crates/polars-utils/src/floor_divmod.rs @@ -0,0 +1,102 @@ +pub trait FloorDivMod: Sized { + // Returns the flooring division and associated modulo of lhs / rhs. + // This is the same division / modulo combination as Python. + // + // Returns (0, 0) if other == 0. + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self); +} + +macro_rules! impl_float_div_mod { + ($T:ty) => { + impl FloorDivMod for $T { + #[inline] + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self) { + let div = (self / other).floor(); + let mod_ = self - other * div; + (div, mod_) + } + } + }; +} + +macro_rules! impl_unsigned_div_mod { + ($T:ty) => { + impl FloorDivMod for $T { + #[inline] + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self) { + (self / other, self % other) + } + } + }; +} + +macro_rules! impl_signed_div_mod { + ($T:ty) => { + impl FloorDivMod for $T { + #[inline] + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self) { + if other == 0 { + return (0, 0); + } + + // Rust/C-style remainder is in the correct congruence + // class, but may not have the right sign. We want a + // remainder with the same sign as the RHS, which we + // can get by adding RHS to the remainder if the sign of + // the non-zero remainder differs from our RHS. + // + // Similarly, Rust/C-style division truncates instead of floors. + // If the remainder was non-zero and the signs were different + // (we'd have a negative result before truncating), we need to + // subtract 1 from the result. + let mut div = self.wrapping_div(other); + let mut mod_ = self.wrapping_rem(other); + if mod_ != 0 && (self < 0) != (other < 0) { + div -= 1; + mod_ += other; + } + (div, mod_) + } + } + }; +} + +impl_unsigned_div_mod!(u8); +impl_unsigned_div_mod!(u16); +impl_unsigned_div_mod!(u32); +impl_unsigned_div_mod!(u64); +impl_unsigned_div_mod!(u128); +impl_unsigned_div_mod!(usize); +impl_signed_div_mod!(i8); +impl_signed_div_mod!(i16); +impl_signed_div_mod!(i32); +impl_signed_div_mod!(i64); +impl_signed_div_mod!(i128); +impl_signed_div_mod!(isize); +impl_float_div_mod!(f32); +impl_float_div_mod!(f64); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_signed_wrapping_div_mod() { + // Test for all i8, should transfer to other values. + for lhs in i8::MIN..=i8::MAX { + for rhs in i8::MIN..=i8::MAX { + let ans = if rhs != 0 { + let fdiv = (lhs as f64 / rhs as f64).floor(); + let fmod = lhs as f64 - rhs as f64 * fdiv; + + // float -> int conversion saturates, we want wrapping, double convert. + ((fdiv as i32) as i8, (fmod as i32) as i8) + } else { + (0, 0) + }; + + assert_eq!(lhs.wrapping_floor_div_mod(rhs), ans); + } + } + } +} diff --git a/crates/polars-utils/src/idx_vec.rs b/crates/polars-utils/src/idx_vec.rs index 0da31836c69c..c0f7098a207d 100644 --- a/crates/polars-utils/src/idx_vec.rs +++ b/crates/polars-utils/src/idx_vec.rs @@ -4,24 +4,26 @@ use std::ops::Deref; use crate::IdxSize; -/// A type logically equivalent to `Vec`, but which does not do a +pub type IdxVec = UnitVec; + +/// A type logically equivalent to `Vec`, but which does not do a /// memory allocation until at least two elements have been pushed, storing the /// first element in the data pointer directly. #[derive(Eq)] -pub struct IdxVec { +pub struct UnitVec { len: usize, capacity: NonZeroUsize, - data: *mut IdxSize, + data: *mut T, } -unsafe impl Send for IdxVec {} -unsafe impl Sync for IdxVec {} +unsafe impl Send for UnitVec {} +unsafe impl Sync for UnitVec {} -impl IdxVec { +impl UnitVec { #[inline(always)] - fn data_ptr_mut(&mut self) -> *mut IdxSize { + fn data_ptr_mut(&mut self) -> *mut T { let external = self.data; - let inline = &mut self.data as *mut *mut IdxSize as *mut IdxSize; + let inline = &mut self.data as *mut *mut T as *mut T; if self.capacity.get() == 1 { inline } else { @@ -30,9 +32,9 @@ impl IdxVec { } #[inline(always)] - fn data_ptr(&self) -> *const IdxSize { + fn data_ptr(&self) -> *const T { let external = self.data; - let inline = &self.data as *const *mut IdxSize as *mut IdxSize; + let inline = &self.data as *const *mut T as *mut T; if self.capacity.get() == 1 { inline } else { @@ -40,7 +42,13 @@ impl IdxVec { } } + #[inline] pub fn new() -> Self { + // This is optimized away, all const. + assert!( + std::mem::size_of::() <= std::mem::size_of::<*mut T>() + && std::mem::align_of::() <= std::mem::align_of::<*mut T>() + ); Self { len: 0, capacity: NonZeroUsize::new(1).unwrap(), @@ -64,7 +72,7 @@ impl IdxVec { } #[inline(always)] - pub fn push(&mut self, idx: IdxSize) { + pub fn push(&mut self, idx: T) { if self.len == self.capacity.get() { self.reserve(1); } @@ -74,8 +82,8 @@ impl IdxVec { #[inline(always)] /// # Safety - /// Caller must ensure that `IdxVec` has enough capacity. - pub unsafe fn push_unchecked(&mut self, idx: IdxSize) { + /// Caller must ensure that `UnitVec` has enough capacity. + pub unsafe fn push_unchecked(&mut self, idx: T) { unsafe { self.data_ptr_mut().add(self.len).write(idx); self.len += 1; @@ -118,36 +126,62 @@ impl IdxVec { new } - pub fn iter(&self) -> std::slice::Iter<'_, IdxSize> { + #[inline] + pub fn iter(&self) -> std::slice::Iter<'_, T> { self.as_slice().iter() } - pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, IdxSize> { + #[inline] + pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> { self.as_mut_slice().iter_mut() } - pub fn as_slice(&self) -> &[IdxSize] { + #[inline] + pub fn as_slice(&self) -> &[T] { self.as_ref() } - pub fn as_mut_slice(&mut self) -> &mut [IdxSize] { + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [T] { self.as_mut() } + + #[inline] + pub fn pop(&mut self) -> Option { + if self.len == 0 { + None + } else { + unsafe { + self.len -= 1; + Some(std::ptr::read(self.as_ptr().add(self.len()))) + } + } + } } -impl Drop for IdxVec { +impl Extend for UnitVec { + fn extend>(&mut self, iter: I) { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + for v in iter { + self.push(v) + } + } +} + +impl Drop for UnitVec { fn drop(&mut self) { self.dealloc() } } -impl Clone for IdxVec { +impl Clone for UnitVec { fn clone(&self) -> Self { unsafe { let mut me = std::mem::ManuallyDrop::new(Vec::with_capacity(self.len)); let buffer = me.as_mut_ptr(); std::ptr::copy(self.data_ptr(), buffer, self.len); - IdxVec { + UnitVec { data: buffer, len: self.len, capacity: NonZeroUsize::new(std::cmp::max(self.len, 1)).unwrap(), @@ -156,13 +190,13 @@ impl Clone for IdxVec { } } -impl Debug for IdxVec { +impl Debug for UnitVec { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "IdxVec: {:?}", self.as_slice()) + write!(f, "UnitVec: {:?}", self.as_slice()) } } -impl Default for IdxVec { +impl Default for UnitVec { fn default() -> Self { Self { len: 0, @@ -172,37 +206,37 @@ impl Default for IdxVec { } } -impl Deref for IdxVec { - type Target = [IdxSize]; +impl Deref for UnitVec { + type Target = [T]; fn deref(&self) -> &Self::Target { self.as_slice() } } -impl AsRef<[IdxSize]> for IdxVec { - fn as_ref(&self) -> &[IdxSize] { +impl AsRef<[T]> for UnitVec { + fn as_ref(&self) -> &[T] { unsafe { std::slice::from_raw_parts(self.data_ptr(), self.len) } } } -impl AsMut<[IdxSize]> for IdxVec { - fn as_mut(&mut self) -> &mut [IdxSize] { +impl AsMut<[T]> for UnitVec { + fn as_mut(&mut self) -> &mut [T] { unsafe { std::slice::from_raw_parts_mut(self.data_ptr_mut(), self.len) } } } -impl PartialEq for IdxVec { +impl PartialEq for UnitVec { fn eq(&self, other: &Self) -> bool { self.as_slice() == other.as_slice() } } -impl FromIterator for IdxVec { - fn from_iter>(iter: T) -> Self { +impl FromIterator for UnitVec { + fn from_iter>(iter: I) -> Self { let iter = iter.into_iter(); if iter.size_hint().0 <= 1 { - let mut new = IdxVec::new(); + let mut new = UnitVec::new(); for v in iter { new.push(v) } @@ -214,17 +248,17 @@ impl FromIterator for IdxVec { } } -impl From> for IdxVec { - fn from(value: Vec) -> Self { +impl From> for UnitVec { + fn from(mut value: Vec) -> Self { if value.capacity() <= 1 { - let mut new = IdxVec::new(); - if let Some(v) = value.first() { - new.push(*v) + let mut new = UnitVec::new(); + if let Some(v) = value.pop() { + new.push(v) } new } else { let mut me = std::mem::ManuallyDrop::new(value); - IdxVec { + UnitVec { data: me.as_mut_ptr(), capacity: NonZeroUsize::new(me.capacity()).unwrap(), len: me.len(), @@ -233,12 +267,12 @@ impl From> for IdxVec { } } -impl From<&[IdxSize]> for IdxVec { - fn from(value: &[IdxSize]) -> Self { +impl From<&[T]> for UnitVec { + fn from(value: &[T]) -> Self { if value.len() <= 1 { - let mut new = IdxVec::new(); + let mut new = UnitVec::new(); if let Some(v) = value.first() { - new.push(*v) + new.push(v.clone()) } new } else { @@ -248,19 +282,19 @@ impl From<&[IdxSize]> for IdxVec { } #[macro_export] -macro_rules! idxvec { +macro_rules! unitvec { () => ( - $crate::idx_vec::IdxVec::new() + $crate::idx_vec::UnitVec::new() ); ($elem:expr; $n:expr) => ( - let mut new = $crate::idx_vec::IdxVec::new(); + let mut new = $crate::idx_vec::UnitVec::new(); for _ in 0..$n { new.push($elem) } new ); ($elem:expr) => ( - {let mut new = $crate::idx_vec::IdxVec::new(); + {let mut new = $crate::idx_vec::UnitVec::new(); // SAFETY: first element always fits. unsafe { new.push_unchecked($elem) }; new} diff --git a/crates/polars-utils/src/index.rs b/crates/polars-utils/src/index.rs index 0dae23160a1e..1815860383f4 100644 --- a/crates/polars-utils/src/index.rs +++ b/crates/polars-utils/src/index.rs @@ -113,3 +113,51 @@ impl_to_idx!(i8, i16); impl_to_idx!(i16, i32); impl_to_idx!(i32, i64); impl_to_idx!(i64, i64); + +// Allows for 2^24 (~16M) chunks +// Leaves 2^40 (~1T) rows per chunk +const CHUNK_BITS: u64 = 24; + +#[derive(Clone, Copy, Debug)] +#[repr(C)] +pub struct ChunkId { + swizzled: u64, +} + +impl ChunkId { + #[inline(always)] + #[allow(clippy::unnecessary_cast)] + pub fn store(chunk: IdxSize, row: IdxSize) -> Self { + debug_assert!(chunk < !(u64::MAX << CHUNK_BITS) as IdxSize); + let swizzled = (row as u64) << CHUNK_BITS | chunk as u64; + + Self { swizzled } + } + + #[inline(always)] + #[allow(clippy::unnecessary_cast)] + pub fn extract(self) -> (IdxSize, IdxSize) { + let row = (self.swizzled >> CHUNK_BITS) as IdxSize; + + const MASK: IdxSize = IdxSize::MAX << CHUNK_BITS; + let chunk = (self.swizzled as IdxSize) & !MASK; + (chunk, row) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_chunk_idx() { + let chunk = 213908; + let row = 813457; + + let ci = ChunkId::store(chunk, row); + let (c, r) = ci.extract(); + + assert_eq!(c, chunk); + assert_eq!(r, row); + } +} diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index ce293cab883c..75e1d290b2ef 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -6,6 +6,7 @@ pub mod cache; pub mod cell; pub mod contention_pool; mod error; +pub mod floor_divmod; pub mod functions; pub mod hashing; pub mod idx_vec; diff --git a/crates/polars-utils/src/ord.rs b/crates/polars-utils/src/ord.rs index 2ffecdea477b..266a68100cba 100644 --- a/crates/polars-utils/src/ord.rs +++ b/crates/polars-utils/src/ord.rs @@ -11,14 +11,14 @@ where // this branch should be optimized away for integers if T::is_float() { match (a.is_nan(), b.is_nan()) { - // safety: we checked nans + // SAFETY: we checked nans (false, false) => unsafe { a.partial_cmp(b).unwrap_unchecked() }, (true, true) => Ordering::Equal, (true, false) => Ordering::Less, (false, true) => Ordering::Greater, } } else { - // Safety: + // SAFETY: // all integers are Ord unsafe { a.partial_cmp(b).unwrap_unchecked() } } @@ -33,14 +33,14 @@ where // this branch should be optimized away for integers if T::is_float() { match (a.is_nan(), b.is_nan()) { - // safety: we checked nans + // SAFETY: we checked nans (false, false) => unsafe { a.partial_cmp(b).unwrap_unchecked() }, (true, true) => Ordering::Equal, (true, false) => Ordering::Greater, (false, true) => Ordering::Less, } } else { - // Safety: + // SAFETY: // all integers are Ord unsafe { a.partial_cmp(b).unwrap_unchecked() } } diff --git a/crates/polars-utils/src/sort.rs b/crates/polars-utils/src/sort.rs index ae6b740933a1..780dc39d1b9c 100644 --- a/crates/polars-utils/src/sort.rs +++ b/crates/polars-utils/src/sort.rs @@ -34,14 +34,14 @@ pub unsafe fn perfect_sort(pool: &ThreadPool, idx: &[(IdxSize, IdxSize)], out: & idx.par_chunks(chunk_size).for_each(|indices| { let ptr = ptr as *mut IdxSize; for (idx_val, idx_location) in indices { - // Safety: + // SAFETY: // idx_location is in bounds by invariant of this function // and we ensured we have at least `idx.len()` capacity *ptr.add(*idx_location as usize) = *idx_val; } }); }); - // Safety: + // SAFETY: // all elements are written out.set_len(idx.len()); } @@ -65,14 +65,14 @@ pub unsafe fn perfect_sort( idx.par_chunks(chunk_size).for_each(|indices| { let ptr = ptr as *mut IdxSize; for (idx_val, idx_location) in indices { - // Safety: + // SAFETY: // idx_location is in bounds by invariant of this function // and we ensured we have at least `idx.len()` capacity *ptr.add(*idx_location as usize) = *idx_val; } }); }); - // Safety: + // SAFETY: // all elements are written out.set_len(idx.len()); } diff --git a/crates/polars-utils/src/total_ord.rs b/crates/polars-utils/src/total_ord.rs index 8dac484d5d96..5b9af065779f 100644 --- a/crates/polars-utils/src/total_ord.rs +++ b/crates/polars-utils/src/total_ord.rs @@ -3,6 +3,9 @@ use std::hash::{Hash, Hasher}; use bytemuck::TransparentWrapper; +use crate::hashing::{BytesHash, DirtyHash}; +use crate::nulls::IsNull; + /// Converts an f32 into a canonical form, where -0 == 0 and all NaNs map to /// the same value. pub fn canonical_f32(x: f32) -> f32 { @@ -32,7 +35,7 @@ pub fn canonical_f64(x: f64) -> f64 { pub trait TotalEq { fn tot_eq(&self, other: &Self) -> bool; - #[inline(always)] + #[inline] fn tot_ne(&self, other: &Self) -> bool { !(self.tot_eq(other)) } @@ -43,22 +46,22 @@ pub trait TotalEq { pub trait TotalOrd: TotalEq { fn tot_cmp(&self, other: &Self) -> Ordering; - #[inline(always)] + #[inline] fn tot_lt(&self, other: &Self) -> bool { self.tot_cmp(other) == Ordering::Less } - #[inline(always)] + #[inline] fn tot_gt(&self, other: &Self) -> bool { self.tot_cmp(other) == Ordering::Greater } - #[inline(always)] + #[inline] fn tot_le(&self, other: &Self) -> bool { self.tot_cmp(other) != Ordering::Greater } - #[inline(always)] + #[inline] fn tot_ge(&self, other: &Self) -> bool { self.tot_cmp(other) != Ordering::Less } @@ -87,46 +90,46 @@ pub struct TotalOrdWrap(pub T); unsafe impl TransparentWrapper for TotalOrdWrap {} impl PartialOrd for TotalOrdWrap { - #[inline(always)] + #[inline] fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } - #[inline(always)] + #[inline] fn lt(&self, other: &Self) -> bool { self.0.tot_lt(&other.0) } - #[inline(always)] + #[inline] fn le(&self, other: &Self) -> bool { self.0.tot_le(&other.0) } - #[inline(always)] + #[inline] fn gt(&self, other: &Self) -> bool { self.0.tot_gt(&other.0) } - #[inline(always)] + #[inline] fn ge(&self, other: &Self) -> bool { self.0.tot_ge(&other.0) } } impl Ord for TotalOrdWrap { - #[inline(always)] + #[inline] fn cmp(&self, other: &Self) -> Ordering { self.0.tot_cmp(&other.0) } } impl PartialEq for TotalOrdWrap { - #[inline(always)] + #[inline] fn eq(&self, other: &Self) -> bool { self.0.tot_eq(&other.0) } - #[inline(always)] + #[inline] #[allow(clippy::partialeq_ne_impl)] fn ne(&self, other: &Self) -> bool { self.0.tot_ne(&other.0) @@ -136,12 +139,14 @@ impl PartialEq for TotalOrdWrap { impl Eq for TotalOrdWrap {} impl Hash for TotalOrdWrap { + #[inline] fn hash(&self, state: &mut H) { self.0.tot_hash(state); } } impl Clone for TotalOrdWrap { + #[inline] fn clone(&self) -> Self { Self(self.0.clone()) } @@ -149,48 +154,85 @@ impl Clone for TotalOrdWrap { impl Copy for TotalOrdWrap {} +impl IsNull for TotalOrdWrap { + const HAS_NULLS: bool = T::HAS_NULLS; + type Inner = T::Inner; + + #[inline] + fn is_null(&self) -> bool { + self.0.is_null() + } + + #[inline] + fn unwrap_inner(self) -> Self::Inner { + self.0.unwrap_inner() + } +} + +impl DirtyHash for f32 { + #[inline] + fn dirty_hash(&self) -> u64 { + canonical_f32(*self).to_bits().dirty_hash() + } +} + +impl DirtyHash for f64 { + #[inline] + fn dirty_hash(&self) -> u64 { + canonical_f64(*self).to_bits().dirty_hash() + } +} + +impl DirtyHash for TotalOrdWrap { + #[inline] + fn dirty_hash(&self) -> u64 { + self.0.dirty_hash() + } +} + macro_rules! impl_trivial_total { ($T: ty) => { impl TotalEq for $T { - #[inline(always)] + #[inline] fn tot_eq(&self, other: &Self) -> bool { self == other } - #[inline(always)] + #[inline] fn tot_ne(&self, other: &Self) -> bool { self != other } } impl TotalOrd for $T { - #[inline(always)] + #[inline] fn tot_cmp(&self, other: &Self) -> Ordering { self.cmp(other) } - #[inline(always)] + #[inline] fn tot_lt(&self, other: &Self) -> bool { self < other } - #[inline(always)] + #[inline] fn tot_gt(&self, other: &Self) -> bool { self > other } - #[inline(always)] + #[inline] fn tot_le(&self, other: &Self) -> bool { self <= other } - #[inline(always)] + #[inline] fn tot_ge(&self, other: &Self) -> bool { self >= other } } impl TotalHash for $T { + #[inline] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -224,7 +266,7 @@ impl_trivial_total!(String); macro_rules! impl_float_eq_ord { ($T:ty) => { impl TotalEq for $T { - #[inline(always)] + #[inline] fn tot_eq(&self, other: &Self) -> bool { if self.is_nan() { other.is_nan() @@ -235,7 +277,7 @@ macro_rules! impl_float_eq_ord { } impl TotalOrd for $T { - #[inline(always)] + #[inline] fn tot_cmp(&self, other: &Self) -> Ordering { if self.tot_lt(other) { Ordering::Less @@ -246,22 +288,22 @@ macro_rules! impl_float_eq_ord { } } - #[inline(always)] + #[inline] fn tot_lt(&self, other: &Self) -> bool { !self.tot_ge(other) } - #[inline(always)] + #[inline] fn tot_gt(&self, other: &Self) -> bool { other.tot_lt(self) } - #[inline(always)] + #[inline] fn tot_le(&self, other: &Self) -> bool { other.tot_ge(self) } - #[inline(always)] + #[inline] fn tot_ge(&self, other: &Self) -> bool { // We consider all NaNs equal, and NaN is the largest possible // value. Thus if self is NaN we always return true. Otherwise @@ -278,6 +320,7 @@ impl_float_eq_ord!(f32); impl_float_eq_ord!(f64); impl TotalHash for f32 { + #[inline] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -287,6 +330,7 @@ impl TotalHash for f32 { } impl TotalHash for f64 { + #[inline] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -297,7 +341,7 @@ impl TotalHash for f64 { // Blanket implementations. impl TotalEq for Option { - #[inline(always)] + #[inline] fn tot_eq(&self, other: &Self) -> bool { match (self, other) { (None, None) => true, @@ -306,7 +350,7 @@ impl TotalEq for Option { } } - #[inline(always)] + #[inline] fn tot_ne(&self, other: &Self) -> bool { match (self, other) { (None, None) => false, @@ -317,7 +361,7 @@ impl TotalEq for Option { } impl TotalOrd for Option { - #[inline(always)] + #[inline] fn tot_cmp(&self, other: &Self) -> Ordering { match (self, other) { (None, None) => Ordering::Equal, @@ -327,7 +371,7 @@ impl TotalOrd for Option { } } - #[inline(always)] + #[inline] fn tot_lt(&self, other: &Self) -> bool { match (self, other) { (None, Some(_)) => true, @@ -336,12 +380,12 @@ impl TotalOrd for Option { } } - #[inline(always)] + #[inline] fn tot_gt(&self, other: &Self) -> bool { other.tot_lt(self) } - #[inline(always)] + #[inline] fn tot_le(&self, other: &Self) -> bool { match (self, other) { (Some(_), None) => false, @@ -350,7 +394,7 @@ impl TotalOrd for Option { } } - #[inline(always)] + #[inline] fn tot_ge(&self, other: &Self) -> bool { other.tot_le(self) } @@ -369,18 +413,19 @@ impl TotalHash for Option { } impl TotalEq for &T { - #[inline(always)] + #[inline] fn tot_eq(&self, other: &Self) -> bool { (*self).tot_eq(*other) } - #[inline(always)] + #[inline] fn tot_ne(&self, other: &Self) -> bool { (*self).tot_ne(*other) } } impl TotalHash for &T { + #[inline] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -402,3 +447,159 @@ impl TotalOrd for (T, U) { .then_with(|| self.1.tot_cmp(&other.1)) } } + +impl<'a> TotalHash for BytesHash<'a> { + #[inline] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state) + } +} + +impl<'a> TotalEq for BytesHash<'a> { + #[inline] + fn tot_eq(&self, other: &Self) -> bool { + self == other + } +} + +/// This elides creating a [`TotalOrdWrap`] for types that don't need it. +pub trait ToTotalOrd { + type TotalOrdItem; + type SourceItem; + + fn to_total_ord(&self) -> Self::TotalOrdItem; + + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem; +} + +macro_rules! impl_to_total_ord_identity { + ($T: ty) => { + impl ToTotalOrd for $T { + type TotalOrdItem = $T; + type SourceItem = $T; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + self.clone() + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item + } + } + }; +} + +impl_to_total_ord_identity!(bool); +impl_to_total_ord_identity!(u8); +impl_to_total_ord_identity!(u16); +impl_to_total_ord_identity!(u32); +impl_to_total_ord_identity!(u64); +impl_to_total_ord_identity!(u128); +impl_to_total_ord_identity!(usize); +impl_to_total_ord_identity!(i8); +impl_to_total_ord_identity!(i16); +impl_to_total_ord_identity!(i32); +impl_to_total_ord_identity!(i64); +impl_to_total_ord_identity!(i128); +impl_to_total_ord_identity!(isize); +impl_to_total_ord_identity!(char); +impl_to_total_ord_identity!(String); + +macro_rules! impl_to_total_ord_lifetimed_ref_identity { + ($T: ty) => { + impl<'a> ToTotalOrd for &'a $T { + type TotalOrdItem = &'a $T; + type SourceItem = &'a $T; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + *self + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item + } + } + }; +} + +impl_to_total_ord_lifetimed_ref_identity!(str); +impl_to_total_ord_lifetimed_ref_identity!([u8]); + +macro_rules! impl_to_total_ord_wrapped { + ($T: ty) => { + impl ToTotalOrd for $T { + type TotalOrdItem = TotalOrdWrap<$T>; + type SourceItem = $T; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + TotalOrdWrap(self.clone()) + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item.0 + } + } + }; +} + +impl_to_total_ord_wrapped!(f32); +impl_to_total_ord_wrapped!(f64); + +/// This is safe without needing to map the option value to TotalOrdWrap, since +/// for example: +/// `TotalOrdWrap>` implements `Eq + Hash`, iff: +/// `Option` implements `TotalEq + TotalHash`, iff: +/// `T` implements `TotalEq + TotalHash` +impl ToTotalOrd for Option { + type TotalOrdItem = TotalOrdWrap>; + type SourceItem = Option; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + TotalOrdWrap(*self) + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item.0 + } +} + +impl ToTotalOrd for &T { + type TotalOrdItem = T::TotalOrdItem; + type SourceItem = T::SourceItem; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + (*self).to_total_ord() + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + T::peel_total_ord(ord_item) + } +} + +impl<'a> ToTotalOrd for BytesHash<'a> { + type TotalOrdItem = BytesHash<'a>; + type SourceItem = BytesHash<'a>; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + *self + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item + } +} diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 9f340dc9b5bd..b1dea59aef56 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -11,17 +11,31 @@ repository = { workspace = true } description = "DataFrame library based on Apache Arrow" [dependencies] +arrow = { workspace = true } polars-core = { workspace = true, features = ["algorithm_group_by"] } +polars-error = { workspace = true } polars-io = { workspace = true, optional = true } polars-lazy = { workspace = true, optional = true } polars-ops = { workspace = true, optional = true } +polars-parquet = { workspace = true } polars-plan = { workspace = true, optional = true } polars-sql = { workspace = true, optional = true } polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } [dev-dependencies] ahash = { workspace = true } +apache-avro = { version = "0.16", features = ["snappy"] } +avro-schema = { workspace = true, features = ["async"] } +either = { workspace = true } +ethnum = "1" +futures = { workspace = true } +# used to run formal property testing +proptest = { version = "1", default_features = false, features = ["std"] } rand = { workspace = true } +# used to test async readers +tokio = { workspace = true, features = ["macros", "rt", "fs", "io-util"] } +tokio-util = { workspace = true, features = ["compat"] } [build-dependencies] version_check = { workspace = true } @@ -35,7 +49,7 @@ sql = ["polars-sql"] rows = ["polars-core/rows"] simd = ["polars-core/simd", "polars-io/simd", "polars-ops?/simd"] avx512 = ["polars-core/avx512"] -nightly = ["polars-core/nightly", "polars-ops?/nightly", "simd", "polars-lazy?/nightly", "polars-sql/nightly"] +nightly = ["polars-core/nightly", "polars-ops?/nightly", "simd", "polars-lazy?/nightly", "polars-sql?/nightly"] docs = ["polars-core/docs"] temporal = ["polars-core/temporal", "polars-lazy?/temporal", "polars-io/temporal", "polars-time"] random = ["polars-core/random", "polars-lazy?/random", "polars-ops/random"] @@ -116,7 +130,7 @@ asof_join = ["polars-core/asof_join", "polars-lazy?/asof_join", "polars-ops/asof bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx"] binary_encoding = ["polars-ops/binary_encoding", "polars-lazy?/binary_encoding", "polars-sql?/binary_encoding"] checked_arithmetic = ["polars-core/checked_arithmetic"] -chunked_ids = ["polars-lazy?/chunked_ids", "polars-core/chunked_ids", "polars-ops/chunked_ids"] +chunked_ids = ["polars-ops?/chunked_ids"] coalesce = ["polars-lazy?/coalesce"] concat_str = ["polars-lazy?/concat_str"] cov = ["polars-lazy/cov"] @@ -162,6 +176,7 @@ list_gather = ["polars-ops/list_gather", "polars-lazy?/list_gather"] list_sample = ["polars-lazy?/list_sample"] list_sets = ["polars-lazy?/list_sets"] list_to_struct = ["polars-ops/list_to_struct", "polars-lazy?/list_to_struct"] +array_to_struct = ["polars-ops/array_to_struct", "polars-lazy?/array_to_struct"] log = ["polars-ops/log", "polars-lazy?/log"] merge_sorted = ["polars-lazy?/merge_sorted"] meta = ["polars-lazy?/meta"] @@ -274,11 +289,13 @@ dtype-array = [ ] dtype-i8 = [ "polars-core/dtype-i8", + "polars-io/dtype-i8", "polars-lazy?/dtype-i8", "polars-ops/dtype-i8", ] dtype-i16 = [ "polars-core/dtype-i16", + "polars-io/dtype-i16", "polars-lazy?/dtype-i16", "polars-ops/dtype-i16", ] @@ -291,11 +308,13 @@ dtype-decimal = [ ] dtype-u8 = [ "polars-core/dtype-u8", + "polars-io/dtype-u8", "polars-lazy?/dtype-u8", "polars-ops/dtype-u8", ] dtype-u16 = [ "polars-core/dtype-u16", + "polars-io/dtype-u16", "polars-lazy?/dtype-u16", "polars-ops/dtype-u16", ] diff --git a/crates/polars/src/docs/eager.rs b/crates/polars/src/docs/eager.rs index 2d07cb524166..a4ad82b91eee 100644 --- a/crates/polars/src/docs/eager.rs +++ b/crates/polars/src/docs/eager.rs @@ -227,7 +227,7 @@ //! ca.lt_eq(&ca); //! //! // use iterators -//! let a: BooleanChunked = ca.into_iter() +//! let a: BooleanChunked = ca.iter() //! .map(|opt_value| { //! match opt_value { //! Some(value) => value < 10, diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index a5c9f55e8b21..3f0e84b3f9b5 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -57,7 +57,7 @@ //! * [Performance](#performance-and-string-data) //! - [Custom allocator](#custom-allocator) //! * [Config](#config-with-env-vars) -//! * [User Guide](#user-guide) +//! * [User guide](#user-guide) //! //! ## Cookbooks //! See examples in the cookbooks: @@ -147,9 +147,7 @@ //! (Note that within an expression there may be more parallelization going on). //! //! Understanding polars expressions is most important when starting with the polars library. Read more -//! about them in the [User Guide](https://docs.pola.rs/user-guide/concepts/expressions). -//! Though the examples given there are in python. The expressions API is almost identical and the -//! the read should certainly be valuable to rust users as well. +//! about them in the [user guide](https://docs.pola.rs/user-guide/concepts/expressions). //! //! ### Eager //! Read more in the pages of the following data structures /traits. @@ -225,12 +223,11 @@ //! - `rows` - Create [`DataFrame`] from rows and extract rows from [`DataFrame`]s. //! And activates `pivot` and `transpose` operations //! - `asof_join` - Join ASOF, to join on nearest keys instead of exact equality match. -//! - `cross_join` - Create the cartesian product of two [`DataFrame`]s. +//! - `cross_join` - Create the Cartesian product of two [`DataFrame`]s. //! - `semi_anti_join` - SEMI and ANTI joins. //! - `group_by_list` - Allow group_by operation on keys of type List. //! - `row_hash` - Utility to hash [`DataFrame`] rows to [`UInt64Chunked`] //! - `diagonal_concat` - Concat diagonally thereby combining different schemas. -//! - `horizontal_concat` - Concat horizontally and extend with null values if lengths don't match //! - `dataframe_arithmetic` - Arithmetic on ([`Dataframe`] and [`DataFrame`]s) and ([`DataFrame`] on [`Series`]) //! - `partition_by` - Split into multiple [`DataFrame`]s partitioned by groups. //! * [`Series`]/[`Expr`] operations: @@ -407,8 +404,9 @@ //! * `POLARS_PANIC_ON_ERR` -> panic instead of returning an Error. //! * `POLARS_NO_CHUNKED_JOIN` -> force rechunk before joins. //! -//! ## User Guide -//! If you want to read more, [check the User Guide](https://docs.pola.rs/). +//! ## User guide +//! +//! If you want to read more, check the [user guide](https://docs.pola.rs/). #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![allow(ambiguous_glob_reexports)] pub mod docs; diff --git a/crates/polars/tests/it/arrow/array/binary/mod.rs b/crates/polars/tests/it/arrow/array/binary/mod.rs new file mode 100644 index 000000000000..3a44b67cbca9 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/mod.rs @@ -0,0 +1,214 @@ +use arrow::array::{Array, BinaryArray}; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::OffsetsBuffer; +use polars_error::PolarsResult; + +mod mutable; +mod mutable_values; +mod to_mutable; + +#[test] +fn basics() { + let data = vec![Some(b"hello".to_vec()), None, Some(b"hello2".to_vec())]; + + let array: BinaryArray = data.into_iter().collect(); + + assert_eq!(array.value(0), b"hello"); + assert_eq!(array.value(1), b""); + assert_eq!(array.value(2), b"hello2"); + assert_eq!(unsafe { array.value_unchecked(2) }, b"hello2"); + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 5, 11]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = BinaryArray::::new( + ArrowDataType::Binary, + array.offsets().clone(), + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert_eq!(array.value(0), b""); + assert_eq!(array.value(1), b"hello2"); + // note how this keeps everything: the offsets were sliced + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[5, 5, 11]); +} + +#[test] +fn empty() { + let array = BinaryArray::::new_empty(ArrowDataType::Binary); + assert_eq!(array.values().as_slice(), b""); + assert_eq!(array.offsets().as_slice(), &[0]); + assert_eq!(array.validity(), None); +} + +#[test] +fn from() { + let array = BinaryArray::::from([Some(b"hello".as_ref()), Some(b" ".as_ref()), None]); + + let a = array.validity().unwrap(); + assert_eq!(a, &Bitmap::from([true, true, false])); +} + +#[test] +fn from_trusted_len_iter() { + let iter = std::iter::repeat(b"hello").take(2).map(Some); + let a = BinaryArray::::from_trusted_len_iter(iter); + assert_eq!(a.len(), 2); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = std::iter::repeat(b"hello".as_ref()) + .take(2) + .map(Some) + .map(PolarsResult::Ok); + let a = BinaryArray::::try_from_trusted_len_iter(iter).unwrap(); + assert_eq!(a.len(), 2); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat(b"hello").take(2).map(Some); + let a: BinaryArray = iter.collect(); + assert_eq!(a.len(), 2); +} + +#[test] +fn with_validity() { + let array = BinaryArray::::from([Some(b"hello".as_ref()), Some(b" ".as_ref()), None]); + + let array = array.with_validity(None); + + let a = array.validity(); + assert_eq!(a, None); +} + +#[test] +#[should_panic] +fn wrong_offsets() { + let offsets = vec![0, 5, 4].try_into().unwrap(); // invalid offsets + let values = Buffer::from(b"abbbbb".to_vec()); + BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); +} + +#[test] +#[should_panic] +fn wrong_data_type() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + BinaryArray::::new(ArrowDataType::Int8, offsets, values, None); +} + +#[test] +#[should_panic] +fn value_with_wrong_offsets_panics() { + let offsets = vec![0, 10, 11, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + // the 10-11 is not checked + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + + // but access is still checked (and panics) + // without checks, this would result in reading beyond bounds + array.value(0); +} + +#[test] +#[should_panic] +fn index_out_of_bounds_panics() { + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + let array = BinaryArray::::new(ArrowDataType::Utf8, offsets, values, None); + + array.value(3); +} + +#[test] +#[should_panic] +fn value_unchecked_with_wrong_offsets_panics() { + let offsets = vec![0, 10, 11, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + // the 10-11 is not checked + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + + // but access is still checked (and panics) + // without checks, this would result in reading beyond bounds, + // even if `0` is in bounds + unsafe { array.value_unchecked(0) }; +} + +#[test] +fn debug() { + let array = BinaryArray::::from([Some([1, 2].as_ref()), Some(&[]), None]); + + assert_eq!(format!("{array:?}"), "BinaryArray[[1, 2], [], None]"); +} + +#[test] +fn into_mut_1() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let a = values.clone(); // cloned values + assert_eq!(a, values); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_2() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let a = offsets.clone(); // cloned offsets + assert_eq!(a, offsets); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_3() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let validity = Some([true].into()); + let a = validity.clone(); // cloned validity + assert_eq!(a, validity); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, validity); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_4() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let validity = Some([true].into()); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, validity); + assert!(array.into_mut().is_right()); +} + +#[test] +fn rev_iter() { + let array = BinaryArray::::from([Some("hello".as_bytes()), Some(" ".as_bytes()), None]); + + assert_eq!( + array.into_iter().rev().collect::>(), + vec![None, Some(" ".as_bytes()), Some("hello".as_bytes())] + ); +} + +#[test] +fn iter_nth() { + let array = BinaryArray::::from([Some("hello"), Some(" "), None]); + + assert_eq!(array.iter().nth(1), Some(Some(" ".as_bytes()))); + assert_eq!(array.iter().nth(10), None); +} diff --git a/crates/polars/tests/it/arrow/array/binary/mutable.rs b/crates/polars/tests/it/arrow/array/binary/mutable.rs new file mode 100644 index 000000000000..d57deb22faa5 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/mutable.rs @@ -0,0 +1,215 @@ +use std::ops::Deref; + +use arrow::array::{BinaryArray, MutableArray, MutableBinaryArray, TryExtendFromSelf}; +use arrow::bitmap::Bitmap; +use polars_error::PolarsError; + +#[test] +fn new() { + assert_eq!(MutableBinaryArray::::new().len(), 0); + + let a = MutableBinaryArray::::with_capacity(2); + assert_eq!(a.len(), 0); + assert!(a.offsets().capacity() >= 2); + assert_eq!(a.values().capacity(), 0); + + let a = MutableBinaryArray::::with_capacities(2, 60); + assert_eq!(a.len(), 0); + assert!(a.offsets().capacity() >= 2); + assert!(a.values().capacity() >= 60); +} + +#[test] +fn from_iter() { + let iter = (0..3u8).map(|x| Some(vec![x; x as usize])); + let a: MutableBinaryArray = iter.clone().collect(); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = unsafe { MutableBinaryArray::::from_trusted_len_iter_unchecked(iter) }; + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); +} + +#[test] +fn from_trusted_len_iter() { + let data = [vec![0; 0], vec![1; 1], vec![2; 2]]; + let a: MutableBinaryArray = data.iter().cloned().map(Some).collect(); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = MutableBinaryArray::::from_trusted_len_iter(data.iter().cloned().map(Some)); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = MutableBinaryArray::::try_from_trusted_len_iter::( + data.iter().cloned().map(Some).map(Ok), + ) + .unwrap(); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = MutableBinaryArray::::from_trusted_len_values_iter(data.iter().cloned()); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); +} + +#[test] +fn push_null() { + let mut array = MutableBinaryArray::::new(); + array.push::<&str>(None); + + let array: BinaryArray = array.into(); + assert_eq!(array.validity(), Some(&Bitmap::from([false]))); +} + +#[test] +fn pop() { + let mut a = MutableBinaryArray::::new(); + a.push(Some(b"first")); + a.push(Some(b"second")); + a.push::>(None); + a.push_null(); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), Some(b"second".to_vec())); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some(b"first".to_vec())); + assert_eq!(a.len(), 0); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 0); +} + +#[test] +fn pop_all_some() { + let mut a = MutableBinaryArray::::new(); + a.push(Some(b"first")); + a.push(Some(b"second")); + a.push(Some(b"third")); + a.push(Some(b"fourth")); + + for _ in 0..4 { + a.push(Some(b"aaaa")); + } + + a.push(Some(b"bbbb")); + + assert_eq!(a.pop(), Some(b"bbbb".to_vec())); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.len(), 5); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.pop(), Some(b"fourth".to_vec())); + assert_eq!(a.pop(), Some(b"third".to_vec())); + assert_eq!(a.pop(), Some(b"second".to_vec())); + assert_eq!(a.pop(), Some(b"first".to_vec())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); +} + +#[test] +fn extend_trusted_len_values() { + let mut array = MutableBinaryArray::::new(); + + array.extend_trusted_len_values(vec![b"first".to_vec(), b"second".to_vec()].into_iter()); + array.extend_trusted_len_values(vec![b"third".to_vec()].into_iter()); + array.extend_trusted_len(vec![None, Some(b"fourth".to_vec())].into_iter()); + + let array: BinaryArray = array.into(); + + assert_eq!(array.values().as_slice(), b"firstsecondthirdfourth"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 11, 16, 16, 22]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00010111], 5)) + ); +} + +#[test] +fn extend_trusted_len() { + let mut array = MutableBinaryArray::::new(); + + array.extend_trusted_len(vec![Some(b"first".to_vec()), Some(b"second".to_vec())].into_iter()); + array.extend_trusted_len(vec![None, Some(b"third".to_vec())].into_iter()); + + let array: BinaryArray = array.into(); + + assert_eq!(array.values().as_slice(), b"firstsecondthird"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 11, 11, 16]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00001011], 4)) + ); +} + +#[test] +fn extend_from_self() { + let mut a = MutableBinaryArray::::from([Some(b"aa"), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableBinaryArray::::from([Some(b"aa"), None, Some(b"aa"), None]) + ); +} + +#[test] +fn test_set_validity() { + let mut array = MutableBinaryArray::::new(); + array.push(Some(b"first")); + array.push(Some(b"second")); + array.push(Some(b"third")); + array.set_validity(Some([false, false, true].into())); + + assert!(!array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); +} + +#[test] +fn test_apply_validity() { + let mut array = MutableBinaryArray::::new(); + array.push(Some(b"first")); + array.push(Some(b"second")); + array.push(Some(b"third")); + array.set_validity(Some([true, true, true].into())); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(!array.is_valid(2)); +} + +#[test] +fn test_apply_validity_with_no_validity_inited() { + let mut array = MutableBinaryArray::::new(); + array.push(Some(b"first")); + array.push(Some(b"second")); + array.push(Some(b"third")); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(array.is_valid(1)); + assert!(array.is_valid(2)); +} diff --git a/crates/polars/tests/it/arrow/array/binary/mutable_values.rs b/crates/polars/tests/it/arrow/array/binary/mutable_values.rs new file mode 100644 index 000000000000..c9e4f1da3bbe --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/mutable_values.rs @@ -0,0 +1,101 @@ +use arrow::array::{MutableArray, MutableBinaryValuesArray}; +use arrow::datatypes::ArrowDataType; + +#[test] +fn capacity() { + let mut b = MutableBinaryValuesArray::::with_capacity(100); + + assert_eq!(b.values().capacity(), 0); + assert!(b.offsets().capacity() >= 100); + b.shrink_to_fit(); + assert!(b.offsets().capacity() < 100); +} + +#[test] +fn offsets_must_be_in_bounds() { + let offsets = vec![0, 10].try_into().unwrap(); + let values = b"abbbbb".to_vec(); + assert!( + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).is_err() + ); +} + +#[test] +fn data_type_must_be_consistent() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = b"abbb".to_vec(); + assert!( + MutableBinaryValuesArray::::try_new(ArrowDataType::Int32, offsets, values).is_err() + ); +} + +#[test] +fn as_box() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + let _ = b.as_box(); +} + +#[test] +fn as_arc() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + let _ = b.as_arc(); +} + +#[test] +fn extend_trusted_len() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 2, 3, 4].try_into().unwrap(); + let values = b"abab".to_vec(); + assert_eq!( + b.as_box(), + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn from_trusted_len() { + let mut b = MutableBinaryValuesArray::::from_trusted_len_iter(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 1, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + assert_eq!( + b.as_box(), + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn extend_from_iter() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let a = b.clone(); + b.extend_trusted_len(a.iter()); + + let offsets = vec![0, 2, 3, 4, 6, 7, 8].try_into().unwrap(); + let values = b"abababab".to_vec(); + assert_eq!( + b.as_box(), + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values) + .unwrap() + .as_box() + ) +} diff --git a/crates/polars/tests/it/arrow/array/binary/to_mutable.rs b/crates/polars/tests/it/arrow/array/binary/to_mutable.rs new file mode 100644 index 000000000000..8f07d3a166b3 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/to_mutable.rs @@ -0,0 +1,70 @@ +use arrow::array::BinaryArray; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +#[test] +fn not_shared() { + let array = BinaryArray::::from([Some("hello"), Some(" "), None]); + assert!(array.into_mut().is_right()); +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_validity() { + let validity = Bitmap::from([true]); + let array = BinaryArray::::new( + ArrowDataType::Binary, + vec![0, 1].try_into().unwrap(), + b"a".to_vec().into(), + Some(validity.clone()), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_values() { + let values: Buffer = b"a".to_vec().into(); + let array = BinaryArray::::new( + ArrowDataType::Binary, + vec![0, 1].try_into().unwrap(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets_values() { + let offsets: Buffer = vec![0, 1].into(); + let values: Buffer = b"a".to_vec().into(); + let array = BinaryArray::::new( + ArrowDataType::Binary, + offsets.clone().try_into().unwrap(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets() { + let offsets: Buffer = vec![0, 1].into(); + let array = BinaryArray::::new( + ArrowDataType::Binary, + offsets.clone().try_into().unwrap(), + b"a".to_vec().into(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_all() { + let array = BinaryArray::::from([Some("hello"), Some(" "), None]); + assert!(array.clone().into_mut().is_left()) +} diff --git a/crates/polars/tests/it/arrow/array/boolean/mod.rs b/crates/polars/tests/it/arrow/array/boolean/mod.rs new file mode 100644 index 000000000000..8b3a4e1e1b70 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/boolean/mod.rs @@ -0,0 +1,146 @@ +use arrow::array::{Array, BooleanArray}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; +use polars_error::PolarsResult; + +mod mutable; + +#[test] +fn basics() { + let data = vec![Some(true), None, Some(false)]; + + let array: BooleanArray = data.into_iter().collect(); + + assert_eq!(array.data_type(), &ArrowDataType::Boolean); + + assert!(array.value(0)); + assert!(!array.value(1)); + assert!(!array.value(2)); + assert!(!unsafe { array.value_unchecked(2) }); + assert_eq!(array.values(), &Bitmap::from_u8_slice([0b00000001], 3)); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = BooleanArray::new( + ArrowDataType::Boolean, + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert!(!array.value(0)); + assert!(!array.value(1)); +} + +#[test] +fn try_new_invalid() { + assert!(BooleanArray::try_new(ArrowDataType::Int32, [true].into(), None).is_err()); + assert!(BooleanArray::try_new( + ArrowDataType::Boolean, + [true].into(), + Some([false, true].into()) + ) + .is_err()); +} + +#[test] +fn with_validity() { + let bitmap = Bitmap::from([true, false, true]); + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, None); + let a = a.with_validity(Some(Bitmap::from([true, false, true]))); + assert!(a.validity().is_some()); +} + +#[test] +fn debug() { + let array = BooleanArray::from([Some(true), None, Some(false)]); + assert_eq!(format!("{array:?}"), "BooleanArray[true, None, false]"); +} + +#[test] +fn into_mut_valid() { + let bitmap = Bitmap::from([true, false, true]); + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, None); + let _ = a.into_mut().right().unwrap(); + + let bitmap = Bitmap::from([true, false, true]); + let validity = Bitmap::from([true, false, true]); + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, Some(validity)); + let _ = a.into_mut().right().unwrap(); +} + +#[test] +fn into_mut_invalid() { + let bitmap = Bitmap::from([true, false, true]); + let _other = bitmap.clone(); // values is shared + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, None); + let _ = a.into_mut().left().unwrap(); + + let bitmap = Bitmap::from([true, false, true]); + let validity = Bitmap::from([true, false, true]); + let _other = validity.clone(); // validity is shared + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, Some(validity)); + let _ = a.into_mut().left().unwrap(); +} + +#[test] +fn empty() { + let array = BooleanArray::new_empty(ArrowDataType::Boolean); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn from_trusted_len_iter() { + let iter = std::iter::repeat(true).take(2).map(Some); + let a = BooleanArray::from_trusted_len_iter(iter.clone()); + assert_eq!(a.len(), 2); + let a = unsafe { BooleanArray::from_trusted_len_iter_unchecked(iter) }; + assert_eq!(a.len(), 2); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = std::iter::repeat(true) + .take(2) + .map(Some) + .map(PolarsResult::Ok); + let a = BooleanArray::try_from_trusted_len_iter(iter.clone()).unwrap(); + assert_eq!(a.len(), 2); + let a = unsafe { BooleanArray::try_from_trusted_len_iter_unchecked(iter).unwrap() }; + assert_eq!(a.len(), 2); +} + +#[test] +fn from_trusted_len_values_iter() { + let iter = std::iter::repeat(true).take(2); + let a = BooleanArray::from_trusted_len_values_iter(iter.clone()); + assert_eq!(a.len(), 2); + let a = unsafe { BooleanArray::from_trusted_len_values_iter_unchecked(iter) }; + assert_eq!(a.len(), 2); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat(true).take(2).map(Some); + let a: BooleanArray = iter.collect(); + assert_eq!(a.len(), 2); +} + +#[test] +fn into_iter() { + let data = vec![Some(true), None, Some(false)]; + let rev = data.clone().into_iter().rev(); + + let array: BooleanArray = data.clone().into_iter().collect(); + + assert_eq!(array.clone().into_iter().collect::>(), data); + + assert!(array.into_iter().rev().eq(rev)) +} diff --git a/crates/polars/tests/it/arrow/array/boolean/mutable.rs b/crates/polars/tests/it/arrow/array/boolean/mutable.rs new file mode 100644 index 000000000000..1071a1ed8c37 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/boolean/mutable.rs @@ -0,0 +1,177 @@ +use arrow::array::{MutableArray, MutableBooleanArray, TryExtendFromSelf}; +use arrow::bitmap::MutableBitmap; +use arrow::datatypes::ArrowDataType; +use polars_error::PolarsResult; + +#[test] +fn set() { + let mut a = MutableBooleanArray::from(&[Some(false), Some(true), Some(false)]); + + a.set(1, None); + a.set(0, Some(true)); + assert_eq!( + a, + MutableBooleanArray::from([Some(true), None, Some(false)]) + ); + assert_eq!(a.values(), &MutableBitmap::from([true, false, false])); +} + +#[test] +fn push() { + let mut a = MutableBooleanArray::new(); + a.push(Some(true)); + a.push(Some(false)); + a.push(None); + a.push_null(); + assert_eq!( + a, + MutableBooleanArray::from([Some(true), Some(false), None, None]) + ); +} + +#[test] +fn pop() { + let mut a = MutableBooleanArray::new(); + a.push(Some(true)); + a.push(Some(false)); + a.push(None); + a.push_null(); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some(true)); + assert_eq!(a.len(), 0); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 0); +} + +#[test] +fn pop_all_some() { + let mut a = MutableBooleanArray::new(); + for _ in 0..4 { + a.push(Some(true)); + } + + for _ in 0..4 { + a.push(Some(false)); + } + + a.push(Some(true)); + + assert_eq!(a.pop(), Some(true)); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.len(), 5); + + assert_eq!( + a, + MutableBooleanArray::from([Some(true), Some(true), Some(true), Some(true), Some(false)]) + ); +} + +#[test] +fn from_trusted_len_iter() { + let iter = std::iter::repeat(true).take(2).map(Some); + let a = MutableBooleanArray::from_trusted_len_iter(iter); + assert_eq!(a, MutableBooleanArray::from([Some(true), Some(true)])); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat(true).take(2).map(Some); + let a: MutableBooleanArray = iter.collect(); + assert_eq!(a, MutableBooleanArray::from([Some(true), Some(true)])); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = vec![Some(true), Some(true), None] + .into_iter() + .map(PolarsResult::Ok); + let a = MutableBooleanArray::try_from_trusted_len_iter(iter).unwrap(); + assert_eq!(a, MutableBooleanArray::from([Some(true), Some(true), None])); +} + +#[test] +fn reserve() { + let mut a = MutableBooleanArray::try_new( + ArrowDataType::Boolean, + MutableBitmap::new(), + Some(MutableBitmap::new()), + ) + .unwrap(); + + a.reserve(10); + assert!(a.validity().unwrap().capacity() > 0); + assert!(a.values().capacity() > 0) +} + +#[test] +fn extend_trusted_len() { + let mut a = MutableBooleanArray::new(); + + a.extend_trusted_len(vec![Some(true), Some(false)].into_iter()); + assert_eq!(a.validity(), None); + + a.extend_trusted_len(vec![None, Some(true)].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([true, true, false, true])) + ); + assert_eq!(a.values(), &MutableBitmap::from([true, false, false, true])); +} + +#[test] +fn extend_trusted_len_values() { + let mut a = MutableBooleanArray::new(); + + a.extend_trusted_len_values(vec![true, true, false].into_iter()); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &MutableBitmap::from([true, true, false])); + + let mut a = MutableBooleanArray::new(); + a.push(None); + a.extend_trusted_len_values(vec![true, false].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); + assert_eq!(a.values(), &MutableBitmap::from([false, true, false])); +} + +#[test] +fn into_iter() { + let ve = MutableBitmap::from([true, false]) + .into_iter() + .collect::>(); + assert_eq!(ve, vec![true, false]); + let ve = MutableBitmap::from([true, false]) + .iter() + .collect::>(); + assert_eq!(ve, vec![true, false]); +} + +#[test] +fn shrink_to_fit() { + let mut a = MutableBitmap::with_capacity(100); + a.push(true); + a.shrink_to_fit(); + assert_eq!(a.capacity(), 8); +} + +#[test] +fn extend_from_self() { + let mut a = MutableBooleanArray::from([Some(true), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableBooleanArray::from([Some(true), None, Some(true), None]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/dictionary/mod.rs b/crates/polars/tests/it/arrow/array/dictionary/mod.rs new file mode 100644 index 000000000000..e14b065e7536 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/dictionary/mod.rs @@ -0,0 +1,214 @@ +mod mutable; + +use arrow::array::*; +use arrow::datatypes::ArrowDataType; + +#[test] +fn try_new_ok() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let data_type = + ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(values.data_type().clone()), false); + let array = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .unwrap(); + + assert_eq!(array.keys(), &PrimitiveArray::from_vec(vec![1i32, 0])); + assert_eq!( + &Utf8Array::::from_slice(["a", "aa"]) as &dyn Array, + array.values().as_ref(), + ); + assert!(!array.is_ordered()); + + assert_eq!(format!("{array:?}"), "DictionaryArray[aa, a]"); +} + +#[test] +fn try_new_incorrect_key() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let data_type = + ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(values.data_type().clone()), false); + + let r = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_nulls() { + let key: Option = None; + let keys = PrimitiveArray::from_iter([key]); + let value: &[&str] = &[]; + let values = Utf8Array::::from_slice(value); + + let data_type = + ArrowDataType::Dictionary(u32::KEY_TYPE, Box::new(values.data_type().clone()), false); + let r = DictionaryArray::try_new(data_type, keys, values.boxed()).is_ok(); + + assert!(r); +} + +#[test] +fn try_new_incorrect_dt() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let data_type = ArrowDataType::Int32; + + let r = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_incorrect_values_dt() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let data_type = + ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(ArrowDataType::LargeUtf8), false); + + let r = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_out_of_bounds() { + let values = Utf8Array::::from_slice(["a", "aa"]); + + let r = DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![2, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_out_of_bounds_neg() { + let values = Utf8Array::::from_slice(["a", "aa"]); + + let r = DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![-1, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn new_null() { + let dt = ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(ArrowDataType::Int32), false); + let array = DictionaryArray::::new_null(dt, 2); + + assert_eq!(format!("{array:?}"), "DictionaryArray[None, None]"); +} + +#[test] +fn new_empty() { + let dt = ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(ArrowDataType::Int32), false); + let array = DictionaryArray::::new_empty(dt); + + assert_eq!(format!("{array:?}"), "DictionaryArray[]"); +} + +#[test] +fn with_validity() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let array = array.with_validity(Some([true, false].into())); + + assert_eq!(format!("{array:?}"), "DictionaryArray[aa, None]"); +} + +#[test] +fn rev_iter() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let mut iter = array.into_iter(); + assert_eq!(iter.by_ref().rev().count(), 2); + assert_eq!(iter.size_hint(), (0, Some(0))); +} + +#[test] +fn iter_values() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let mut iter = array.values_iter(); + assert_eq!(iter.by_ref().count(), 2); + assert_eq!(iter.size_hint(), (0, Some(0))); +} + +#[test] +fn keys_values_iter() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + assert_eq!(array.keys_values_iter().collect::>(), vec![1, 0]); +} + +#[test] +fn iter_values_typed() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + let iter = array.values_iter_typed::>().unwrap(); + assert_eq!(iter.size_hint(), (3, Some(3))); + assert_eq!(iter.collect::>(), vec!["aa", "a", "a"]); + + let iter = array.iter_typed::>().unwrap(); + assert_eq!(iter.size_hint(), (3, Some(3))); + assert_eq!( + iter.collect::>(), + vec![Some("aa"), Some("a"), Some("a")] + ); +} + +#[test] +#[should_panic] +fn iter_values_typed_panic() { + let values = Utf8Array::::from_iter([Some("a"), Some("aa"), None]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + // should not be iterating values + let iter = array.values_iter_typed::>().unwrap(); + let _ = iter.collect::>(); +} + +#[test] +#[should_panic] +fn iter_values_typed_panic_2() { + let values = Utf8Array::::from_iter([Some("a"), Some("aa"), None]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + // should not be iterating values + let iter = array.iter_typed::>().unwrap(); + let _ = iter.collect::>(); +} diff --git a/crates/polars/tests/it/arrow/array/dictionary/mutable.rs b/crates/polars/tests/it/arrow/array/dictionary/mutable.rs new file mode 100644 index 000000000000..cc8b0774533a --- /dev/null +++ b/crates/polars/tests/it/arrow/array/dictionary/mutable.rs @@ -0,0 +1,169 @@ +use std::borrow::Borrow; +use std::fmt::Debug; +use std::hash::Hash; + +use arrow::array::indexable::{AsIndexed, Indexable}; +use arrow::array::*; +use polars_error::PolarsResult; +use polars_utils::aliases::{InitHashMaps, PlHashSet}; + +#[test] +fn primitive() -> PolarsResult<()> { + let data = vec![Some(1), Some(2), Some(1)]; + + let mut a = MutableDictionaryArray::>::new(); + a.try_extend(data)?; + assert_eq!(a.len(), 3); + assert_eq!(a.values().len(), 2); + Ok(()) +} + +#[test] +fn utf8_natural() -> PolarsResult<()> { + let data = vec![Some("a"), Some("b"), Some("a")]; + + let mut a = MutableDictionaryArray::>::new(); + a.try_extend(data)?; + + assert_eq!(a.len(), 3); + assert_eq!(a.values().len(), 2); + Ok(()) +} + +#[test] +fn binary_natural() -> PolarsResult<()> { + let data = vec![ + Some("a".as_bytes()), + Some("b".as_bytes()), + Some("a".as_bytes()), + ]; + + let mut a = MutableDictionaryArray::>::new(); + a.try_extend(data)?; + assert_eq!(a.len(), 3); + assert_eq!(a.values().len(), 2); + Ok(()) +} + +#[test] +fn push_utf8() { + let mut new: MutableDictionaryArray> = MutableDictionaryArray::new(); + + for value in [Some("A"), Some("B"), None, Some("C"), Some("A"), Some("B")] { + new.try_push(value).unwrap(); + } + + assert_eq!( + new.values().values(), + MutableUtf8Array::::from_iter_values(["A", "B", "C"].into_iter()).values() + ); + + let mut expected_keys = MutablePrimitiveArray::::from_slice([0, 1]); + expected_keys.push(None); + expected_keys.push(Some(2)); + expected_keys.push(Some(0)); + expected_keys.push(Some(1)); + assert_eq!(*new.keys(), expected_keys); +} + +#[test] +fn into_empty() { + let mut new: MutableDictionaryArray> = MutableDictionaryArray::new(); + for value in [Some("A"), Some("B"), None, Some("C"), Some("A"), Some("B")] { + new.try_push(value).unwrap(); + } + let values = new.values().clone(); + let empty = new.into_empty(); + assert_eq!(empty.values(), &values); + assert!(empty.is_empty()); +} + +#[test] +fn from_values() { + let mut new: MutableDictionaryArray> = MutableDictionaryArray::new(); + for value in [Some("A"), Some("B"), None, Some("C"), Some("A"), Some("B")] { + new.try_push(value).unwrap(); + } + let mut values = new.values().clone(); + let empty = MutableDictionaryArray::::from_values(values.clone()).unwrap(); + assert_eq!(empty.values(), &values); + assert!(empty.is_empty()); + values.push(Some("A")); + assert!(MutableDictionaryArray::::from_values(values).is_err()); +} + +#[test] +fn try_empty() { + let mut values = MutableUtf8Array::::new(); + MutableDictionaryArray::::try_empty(values.clone()).unwrap(); + values.push(Some("A")); + assert!(MutableDictionaryArray::::try_empty(values.clone()).is_err()); +} + +fn test_push_ex(values: Vec, gen: impl Fn(usize) -> T) +where + M: MutableArray + Indexable + TryPush> + TryExtend> + Default + 'static, + M::Type: Eq + Hash + Debug, + T: AsIndexed + Default + Clone + Eq + Hash, +{ + for is_extend in [false, true] { + let mut set = PlHashSet::new(); + let mut arr = MutableDictionaryArray::::new(); + macro_rules! push { + ($v:expr) => { + if is_extend { + arr.try_extend(std::iter::once($v)) + } else { + arr.try_push($v) + } + }; + } + arr.push_null(); + push!(None).unwrap(); + assert_eq!(arr.len(), 2); + assert_eq!(arr.values().len(), 0); + for (i, v) in values.iter().cloned().enumerate() { + push!(Some(v.clone())).unwrap(); + let is_dup = !set.insert(v.clone()); + if !is_dup { + assert_eq!(arr.values().value_at(i).borrow(), v.as_indexed()); + assert_eq!(arr.keys().value_at(arr.keys().len() - 1), i as u8); + } + assert_eq!(arr.values().len(), set.len()); + assert_eq!(arr.len(), 3 + i); + } + for i in 0..256 - set.len() { + push!(Some(gen(i))).unwrap(); + } + assert!(push!(Some(gen(256))).is_err()); + } +} + +#[test] +fn test_push_utf8_ex() { + test_push_ex::, _>(vec!["a".into(), "b".into(), "a".into()], |i| { + i.to_string() + }) +} + +#[test] +fn test_push_i64_ex() { + test_push_ex::, _>(vec![10, 20, 30, 20], |i| 1000 + i as i64); +} + +#[test] +fn test_big_dict() { + let n = 10; + let strings = (0..10).map(|i| i.to_string()).collect::>(); + let mut arr = MutableDictionaryArray::>::new(); + for s in &strings { + arr.try_push(Some(s)).unwrap(); + } + assert_eq!(arr.values().len(), n); + for _ in 0..10_000 { + for s in &strings { + arr.try_push(Some(s)).unwrap(); + } + } + assert_eq!(arr.values().len(), n); +} diff --git a/crates/polars/tests/it/arrow/array/equal/boolean.rs b/crates/polars/tests/it/arrow/array/equal/boolean.rs new file mode 100644 index 000000000000..e20be510879f --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/boolean.rs @@ -0,0 +1,53 @@ +use arrow::array::*; + +use super::test_equal; + +#[test] +fn test_boolean_equal() { + let a = BooleanArray::from_slice([false, false, true]); + let b = BooleanArray::from_slice([false, false, true]); + test_equal(&a, &b, true); + + let b = BooleanArray::from_slice([false, false, false]); + test_equal(&a, &b, false); +} + +#[test] +fn test_boolean_equal_null() { + let a = BooleanArray::from(vec![Some(false), None, None, Some(true)]); + let b = BooleanArray::from(vec![Some(false), None, None, Some(true)]); + test_equal(&a, &b, true); + + let b = BooleanArray::from(vec![None, None, None, Some(true)]); + test_equal(&a, &b, false); + + let b = BooleanArray::from(vec![Some(true), None, None, Some(true)]); + test_equal(&a, &b, false); +} + +#[test] +fn test_boolean_equal_offset() { + let a = BooleanArray::from_slice(vec![false, true, false, true, false, false, true]); + let b = BooleanArray::from_slice(vec![true, false, false, false, true, false, true, true]); + test_equal(&a, &b, false); + + let a_slice = a.sliced(2, 3); + let b_slice = b.sliced(3, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.sliced(3, 4); + let b_slice = b.sliced(4, 4); + test_equal(&a_slice, &b_slice, false); + + // Elements fill in `u8`'s exactly. + let mut vector = vec![false, false, true, true, true, true, true, true]; + let a = BooleanArray::from_slice(vector.clone()); + let b = BooleanArray::from_slice(vector.clone()); + test_equal(&a, &b, true); + + // Elements fill in `u8`s + suffix bits. + vector.push(true); + let a = BooleanArray::from_slice(vector.clone()); + let b = BooleanArray::from_slice(vector); + test_equal(&a, &b, true); +} diff --git a/crates/polars/tests/it/arrow/array/equal/dictionary.rs b/crates/polars/tests/it/arrow/array/equal/dictionary.rs new file mode 100644 index 000000000000..b429c71b4e69 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/dictionary.rs @@ -0,0 +1,97 @@ +use arrow::array::*; + +use super::test_equal; + +fn create_dictionary_array(values: &[Option<&str>], keys: &[Option]) -> DictionaryArray { + let keys = Int16Array::from(keys); + let values = Utf8Array::::from(values); + + DictionaryArray::try_from_keys(keys, values.boxed()).unwrap() +} + +#[test] +fn dictionary_equal() { + // (a, b, c), (0, 1, 0, 2) => (a, b, a, c) + let a = create_dictionary_array( + &[Some("a"), Some("b"), Some("c")], + &[Some(0), Some(1), Some(0), Some(2)], + ); + // different representation (values and keys are swapped), same result + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(0), Some(1)], + ); + test_equal(&a, &b, true); + + // different len + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(1)], + ); + test_equal(&a, &b, false); + + // different key + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(0), Some(0)], + ); + test_equal(&a, &b, false); + + // different values, same keys + let b = create_dictionary_array( + &[Some("a"), Some("b"), Some("d")], + &[Some(0), Some(1), Some(0), Some(2)], + ); + test_equal(&a, &b, false); +} + +#[test] +fn dictionary_equal_null() { + // (a, b, c), (1, 2, 1, 3) => (a, b, a, c) + let a = create_dictionary_array( + &[Some("a"), Some("b"), Some("c")], + &[Some(0), None, Some(0), Some(2)], + ); + + // equal to self + test_equal(&a, &a, true); + + // different representation (values and keys are swapped), same result + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), None, Some(0), Some(1)], + ); + test_equal(&a, &b, true); + + // different null position + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(0), None], + ); + test_equal(&a, &b, false); + + // different key + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), None, Some(0), Some(0)], + ); + test_equal(&a, &b, false); + + // different values, same keys + let b = create_dictionary_array( + &[Some("a"), Some("b"), Some("d")], + &[Some(0), None, Some(0), Some(2)], + ); + test_equal(&a, &b, false); + + // different nulls in keys and values + let a = create_dictionary_array( + &[Some("a"), Some("b"), None], + &[Some(0), None, Some(0), Some(2)], + ); + let b = create_dictionary_array( + &[Some("a"), Some("b"), Some("c")], + &[Some(0), None, Some(0), None], + ); + test_equal(&a, &b, true); +} diff --git a/crates/polars/tests/it/arrow/array/equal/fixed_size_list.rs b/crates/polars/tests/it/arrow/array/equal/fixed_size_list.rs new file mode 100644 index 000000000000..04238ab7362f --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/fixed_size_list.rs @@ -0,0 +1,84 @@ +use arrow::array::{ + FixedSizeListArray, MutableFixedSizeListArray, MutablePrimitiveArray, TryExtend, +}; + +use super::test_equal; + +/// Create a fixed size list of 2 value lengths +fn create_fixed_size_list_array, T: AsRef<[Option]>>( + data: T, +) -> FixedSizeListArray { + let data = data.as_ref().iter().map(|x| { + Some(match x { + Some(x) => x.as_ref().iter().map(|x| Some(*x)).collect::>(), + None => std::iter::repeat(None).take(3).collect::>(), + }) + }); + + let mut list = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + list.try_extend(data).unwrap(); + list.into() +} + +#[test] +fn test_fixed_size_list_equal() { + let a = create_fixed_size_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + let b = create_fixed_size_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + test_equal(&a, &b, true); + + let b = create_fixed_size_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 7])]); + test_equal(&a, &b, false); +} + +// Test the case where null_count > 0 +#[test] +fn test_fixed_list_null() { + let a = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[4, 5, 6]), None, None]); + /* + let b = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + None, + Some(&[4, 5, 6]), + None, + None, + ]); + test_equal(&a, &b, true); + + let b = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + Some(&[7, 8, 9]), + Some(&[4, 5, 6]), + None, + None, + ]); + test_equal(&a, &b, false); + */ + + let b = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[3, 6, 9]), None, None]); + test_equal(&a, &b, false); +} + +#[test] +fn test_fixed_list_offsets() { + // Test the case where offset != 0 + let a = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[4, 5, 6]), None, None]); + let b = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[3, 6, 9]), None, None]); + + let a_slice = a.clone().sliced(0, 3); + let b_slice = b.clone().sliced(0, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.clone().sliced(0, 5); + let b_slice = b.clone().sliced(0, 5); + test_equal(&a_slice, &b_slice, false); + + let a_slice = a.sliced(4, 1); + let b_slice = b.sliced(4, 1); + test_equal(&a_slice, &b_slice, true); +} diff --git a/crates/polars/tests/it/arrow/array/equal/list.rs b/crates/polars/tests/it/arrow/array/equal/list.rs new file mode 100644 index 000000000000..34370ad5459e --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/list.rs @@ -0,0 +1,90 @@ +use arrow::array::{Int32Array, ListArray, MutableListArray, MutablePrimitiveArray, TryExtend}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; + +use super::test_equal; + +fn create_list_array, T: AsRef<[Option]>>(data: T) -> ListArray { + let iter = data.as_ref().iter().map(|x| { + x.as_ref() + .map(|x| x.as_ref().iter().map(|x| Some(*x)).collect::>()) + }); + let mut array = MutableListArray::>::new(); + array.try_extend(iter).unwrap(); + array.into() +} + +#[test] +fn test_list_equal() { + let a = create_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + let b = create_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + test_equal(&a, &b, true); + + let b = create_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 7])]); + test_equal(&a, &b, false); +} + +// Test the case where null_count > 0 +#[test] +fn test_list_null() { + let a = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + let b = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + test_equal(&a, &b, true); + + let b = create_list_array([ + Some(&[1, 2]), + None, + Some(&[5, 6]), + Some(&[3, 4]), + None, + None, + ]); + test_equal(&a, &b, false); + + let b = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 5]), None, None]); + test_equal(&a, &b, false); +} + +// Test the case where offset != 0 +#[test] +fn test_list_offsets() { + let a = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + let b = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 5]), None, None]); + + let a_slice = a.clone().sliced(0, 3); + let b_slice = b.clone().sliced(0, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.clone().sliced(0, 5); + let b_slice = b.clone().sliced(0, 5); + test_equal(&a_slice, &b_slice, false); + + let a_slice = a.sliced(4, 1); + let b_slice = b.sliced(4, 1); + test_equal(&a_slice, &b_slice, true); +} + +#[test] +fn test_bla() { + let offsets = vec![0, 3, 3, 6].try_into().unwrap(); + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let values = Box::new(Int32Array::from([ + Some(1), + Some(2), + Some(3), + Some(4), + None, + Some(6), + ])); + let validity = Bitmap::from([true, false, true]); + let lhs = ListArray::::new(data_type, offsets, values, Some(validity)); + let lhs = lhs.sliced(1, 2); + + let offsets = vec![0, 0, 3].try_into().unwrap(); + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let values = Box::new(Int32Array::from([Some(4), None, Some(6)])); + let validity = Bitmap::from([false, true]); + let rhs = ListArray::::new(data_type, offsets, values, Some(validity)); + + assert_eq!(lhs, rhs); +} diff --git a/crates/polars/tests/it/arrow/array/equal/mod.rs b/crates/polars/tests/it/arrow/array/equal/mod.rs new file mode 100644 index 000000000000..87f7ffeff251 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/mod.rs @@ -0,0 +1,50 @@ +use arrow::array::*; + +mod dictionary; +mod fixed_size_list; +mod list; +mod primitive; +mod utf8; + +pub fn test_equal(lhs: &dyn Array, rhs: &dyn Array, expected: bool) { + // equality is symmetric + assert!(equal(lhs, lhs), "\n{lhs:?}\n{lhs:?}"); + assert!(equal(rhs, rhs), "\n{rhs:?}\n{rhs:?}"); + + assert_eq!(equal(lhs, rhs), expected, "\n{lhs:?}\n{rhs:?}"); + assert_eq!(equal(rhs, lhs), expected, "\n{rhs:?}\n{lhs:?}"); +} + +#[allow(clippy::type_complexity)] +fn binary_cases() -> Vec<(Vec>, Vec>, bool)> { + let base = vec![ + Some("hello".to_owned()), + None, + None, + Some("world".to_owned()), + None, + None, + ]; + let not_base = vec![ + Some("hello".to_owned()), + Some("foo".to_owned()), + None, + Some("world".to_owned()), + None, + None, + ]; + vec![ + ( + vec![Some("hello".to_owned()), Some("world".to_owned())], + vec![Some("hello".to_owned()), Some("world".to_owned())], + true, + ), + ( + vec![Some("hello".to_owned()), Some("world".to_owned())], + vec![Some("hello".to_owned()), Some("arrow".to_owned())], + false, + ), + (base.clone(), base.clone(), true), + (base, not_base, false), + ] +} diff --git a/crates/polars/tests/it/arrow/array/equal/primitive.rs b/crates/polars/tests/it/arrow/array/equal/primitive.rs new file mode 100644 index 000000000000..e50711eb9728 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/primitive.rs @@ -0,0 +1,90 @@ +use arrow::array::*; + +use super::test_equal; + +#[test] +fn test_primitive() { + let cases = vec![ + ( + vec![Some(1), Some(2), Some(3)], + vec![Some(1), Some(2), Some(3)], + true, + ), + ( + vec![Some(1), Some(2), Some(3)], + vec![Some(1), Some(2), Some(4)], + false, + ), + ( + vec![Some(1), Some(2), None], + vec![Some(1), Some(2), None], + true, + ), + ( + vec![Some(1), None, Some(3)], + vec![Some(1), Some(2), None], + false, + ), + ( + vec![Some(1), None, None], + vec![Some(1), Some(2), None], + false, + ), + ]; + + for (lhs, rhs, expected) in cases { + let lhs = Int32Array::from(&lhs); + let rhs = Int32Array::from(&rhs); + test_equal(&lhs, &rhs, expected); + } +} + +#[test] +fn test_primitive_slice() { + let cases = vec![ + ( + vec![Some(1), Some(2), Some(3)], + (0, 1), + vec![Some(1), Some(2), Some(3)], + (0, 1), + true, + ), + ( + vec![Some(1), Some(2), Some(3)], + (1, 1), + vec![Some(1), Some(2), Some(3)], + (2, 1), + false, + ), + ( + vec![Some(1), Some(2), None], + (1, 1), + vec![Some(1), None, Some(2)], + (2, 1), + true, + ), + ( + vec![None, Some(2), None], + (1, 1), + vec![None, None, Some(2)], + (2, 1), + true, + ), + ( + vec![Some(1), None, Some(2), None, Some(3)], + (2, 2), + vec![None, Some(2), None, Some(3)], + (1, 2), + true, + ), + ]; + + for (lhs, slice_lhs, rhs, slice_rhs, expected) in cases { + let lhs = Int32Array::from(&lhs); + let lhs = lhs.sliced(slice_lhs.0, slice_lhs.1); + let rhs = Int32Array::from(&rhs); + let rhs = rhs.sliced(slice_rhs.0, slice_rhs.1); + + test_equal(&lhs, &rhs, expected); + } +} diff --git a/crates/polars/tests/it/arrow/array/equal/utf8.rs b/crates/polars/tests/it/arrow/array/equal/utf8.rs new file mode 100644 index 000000000000..a9f9e6cff069 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/utf8.rs @@ -0,0 +1,26 @@ +use arrow::array::*; +use arrow::offset::Offset; + +use super::{binary_cases, test_equal}; + +fn test_generic_string_equal() { + let cases = binary_cases(); + + for (lhs, rhs, expected) in cases { + let lhs = lhs.iter().map(|x| x.as_deref()); + let rhs = rhs.iter().map(|x| x.as_deref()); + let lhs = Utf8Array::::from_trusted_len_iter(lhs); + let rhs = Utf8Array::::from_trusted_len_iter(rhs); + test_equal(&lhs, &rhs, expected); + } +} + +#[test] +fn utf8_equal() { + test_generic_string_equal::() +} + +#[test] +fn large_utf8_equal() { + test_generic_string_equal::() +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs b/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs new file mode 100644 index 000000000000..12019be64205 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs @@ -0,0 +1,103 @@ +use arrow::array::FixedSizeBinaryArray; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +mod mutable; + +#[test] +fn basics() { + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + Buffer::from(vec![1, 2, 3, 4, 5, 6]), + Some(Bitmap::from([true, false, true])), + ); + assert_eq!(array.size(), 2); + assert_eq!(array.len(), 3); + assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); + + assert_eq!(array.value(0), [1, 2]); + assert_eq!(array.value(2), [5, 6]); + + let array = array.sliced(1, 2); + + assert_eq!(array.value(1), [5, 6]); +} + +#[test] +fn with_validity() { + let a = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + vec![1, 2, 3, 4, 5, 6].into(), + None, + ); + let a = a.with_validity(Some(Bitmap::from([true, false, true]))); + assert!(a.validity().is_some()); +} + +#[test] +fn debug() { + let a = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + vec![1, 2, 3, 4, 5, 6].into(), + Some(Bitmap::from([true, false, true])), + ); + assert_eq!(format!("{a:?}"), "FixedSizeBinary(2)[[1, 2], None, [5, 6]]"); +} + +#[test] +fn empty() { + let array = FixedSizeBinaryArray::new_empty(ArrowDataType::FixedSizeBinary(2)); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn null() { + let array = FixedSizeBinaryArray::new_null(ArrowDataType::FixedSizeBinary(2), 2); + assert_eq!(array.values().len(), 4); + assert_eq!(array.validity().cloned(), Some([false, false].into())); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat(vec![1u8, 2]).take(2).map(Some); + let a = FixedSizeBinaryArray::from_iter(iter, 2); + assert_eq!(a.len(), 2); +} + +#[test] +fn wrong_size() { + let values = Buffer::from(b"abb".to_vec()); + assert!( + FixedSizeBinaryArray::try_new(ArrowDataType::FixedSizeBinary(2), values, None).is_err() + ); +} + +#[test] +fn wrong_len() { + let values = Buffer::from(b"abba".to_vec()); + let validity = Some([true, false, false].into()); // it should be 2 + assert!( + FixedSizeBinaryArray::try_new(ArrowDataType::FixedSizeBinary(2), values, validity).is_err() + ); +} + +#[test] +fn wrong_data_type() { + let values = Buffer::from(b"abba".to_vec()); + assert!(FixedSizeBinaryArray::try_new(ArrowDataType::Binary, values, None).is_err()); +} + +#[test] +fn to() { + let values = Buffer::from(b"abba".to_vec()); + let a = FixedSizeBinaryArray::new(ArrowDataType::FixedSizeBinary(2), values, None); + + let extension = ArrowDataType::Extension( + "a".to_string(), + Box::new(ArrowDataType::FixedSizeBinary(2)), + None, + ); + let _ = a.to(extension); +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs b/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs new file mode 100644 index 000000000000..316157087fbb --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs @@ -0,0 +1,173 @@ +use arrow::array::*; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::datatypes::ArrowDataType; + +#[test] +fn basic() { + let a = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + None, + ) + .unwrap(); + assert_eq!(a.len(), 2); + assert_eq!(a.data_type(), &ArrowDataType::FixedSizeBinary(2)); + assert_eq!(a.values(), &Vec::from([1, 2, 3, 4])); + assert_eq!(a.validity(), None); + assert_eq!(a.value(1), &[3, 4]); + assert_eq!(unsafe { a.value_unchecked(1) }, &[3, 4]); +} + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + None, + ) + .unwrap(); + assert_eq!(a, a); + let b = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2]), + None, + ) + .unwrap(); + assert_eq!(b, b); + assert!(a != b); + let a = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + let b = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + Some(MutableBitmap::from([false, true])), + ) + .unwrap(); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); +} + +#[test] +fn try_from_iter() { + let array = MutableFixedSizeBinaryArray::try_from_iter( + vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], + 2, + ) + .unwrap(); + assert_eq!(array.len(), 4); +} + +#[test] +fn push_null() { + let mut array = MutableFixedSizeBinaryArray::new(2); + array.push::<&[u8]>(None); + + let array: FixedSizeBinaryArray = array.into(); + assert_eq!(array.validity(), Some(&Bitmap::from([false]))); +} + +#[test] +fn pop() { + let mut a = MutableFixedSizeBinaryArray::new(2); + a.push(Some(b"aa")); + a.push::<&[u8]>(None); + a.push(Some(b"bb")); + a.push::<&[u8]>(None); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), Some(b"bb".to_vec())); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some(b"aa".to_vec())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); + assert!(a.is_empty()); +} + +#[test] +fn pop_all_some() { + let mut a = MutableFixedSizeBinaryArray::new(2); + a.push(Some(b"aa")); + a.push(Some(b"bb")); + a.push(Some(b"cc")); + a.push(Some(b"dd")); + + for _ in 0..4 { + a.push(Some(b"11")); + } + + a.push(Some(b"22")); + + assert_eq!(a.pop(), Some(b"22".to_vec())); + assert_eq!(a.pop(), Some(b"11".to_vec())); + assert_eq!(a.pop(), Some(b"11".to_vec())); + assert_eq!(a.pop(), Some(b"11".to_vec())); + assert_eq!(a.len(), 5); + + assert_eq!( + a, + MutableFixedSizeBinaryArray::try_from_iter( + vec![ + Some(b"aa"), + Some(b"bb"), + Some(b"cc"), + Some(b"dd"), + Some(b"11"), + ], + 2, + ) + .unwrap() + ); +} + +#[test] +fn as_arc() { + let mut array = MutableFixedSizeBinaryArray::try_from_iter( + vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], + 2, + ) + .unwrap(); + + let array = array.as_arc(); + assert_eq!(array.len(), 4); +} + +#[test] +fn as_box() { + let mut array = MutableFixedSizeBinaryArray::try_from_iter( + vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], + 2, + ) + .unwrap(); + + let array = array.as_box(); + assert_eq!(array.len(), 4); +} + +#[test] +fn shrink_to_fit_and_capacity() { + let mut array = MutableFixedSizeBinaryArray::with_capacity(2, 100); + array.push(Some([1, 2])); + array.shrink_to_fit(); + assert_eq!(array.capacity(), 1); +} + +#[test] +fn extend_from_self() { + let mut a = MutableFixedSizeBinaryArray::from([Some([1u8, 2u8]), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableFixedSizeBinaryArray::from([Some([1u8, 2u8]), None, Some([1u8, 2u8]), None]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs new file mode 100644 index 000000000000..d178b27e190b --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs @@ -0,0 +1,102 @@ +mod mutable; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ArrowDataType, Field}; + +fn data() -> FixedSizeListArray { + let values = Int32Array::from_slice([10, 20, 0, 0]); + + FixedSizeListArray::try_new( + ArrowDataType::FixedSizeList( + Box::new(Field::new("a", values.data_type().clone(), true)), + 2, + ), + values.boxed(), + Some([true, false].into()), + ) + .unwrap() +} + +#[test] +fn basics() { + let array = data(); + assert_eq!(array.size(), 2); + assert_eq!(array.len(), 2); + assert_eq!(array.validity(), Some(&Bitmap::from([true, false]))); + + assert_eq!(array.value(0).as_ref(), Int32Array::from_slice([10, 20])); + assert_eq!(array.value(1).as_ref(), Int32Array::from_slice([0, 0])); + + let array = array.sliced(1, 1); + + assert_eq!(array.value(0).as_ref(), Int32Array::from_slice([0, 0])); +} + +#[test] +fn with_validity() { + let array = data(); + + let a = array.with_validity(None); + assert!(a.validity().is_none()); +} + +#[test] +fn debug() { + let array = data(); + + assert_eq!(format!("{array:?}"), "FixedSizeListArray[[10, 20], None]"); +} + +#[test] +fn empty() { + let array = FixedSizeListArray::new_empty(ArrowDataType::FixedSizeList( + Box::new(Field::new("a", ArrowDataType::Int32, true)), + 2, + )); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn null() { + let array = FixedSizeListArray::new_null( + ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Int32, true)), 2), + 2, + ); + assert_eq!(array.values().len(), 4); + assert_eq!(array.validity().cloned(), Some([false, false].into())); +} + +#[test] +fn wrong_size() { + let values = Int32Array::from_slice([10, 20, 0]); + assert!(FixedSizeListArray::try_new( + ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Int32, true)), 2), + values.boxed(), + None + ) + .is_err()); +} + +#[test] +fn wrong_len() { + let values = Int32Array::from_slice([10, 20, 0]); + assert!(FixedSizeListArray::try_new( + ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Int32, true)), 2), + values.boxed(), + Some([true, false, false].into()), // it should be 2 + ) + .is_err()); +} + +#[test] +fn wrong_data_type() { + let values = Int32Array::from_slice([10, 20, 0]); + assert!(FixedSizeListArray::try_new( + ArrowDataType::Binary, + values.boxed(), + Some([true, false, false].into()), // it should be 2 + ) + .is_err()); +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs b/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs new file mode 100644 index 000000000000..23ea53231059 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs @@ -0,0 +1,88 @@ +use arrow::array::*; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn primitive() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![None, None, None]), + Some(vec![Some(4), None, Some(6)]), + ]; + + let mut list = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + list.try_extend(data).unwrap(); + let list: FixedSizeListArray = list.into(); + + let a = list.value(0); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![Some(1i32), Some(2), Some(3)]); + assert_eq!(a, &expected); + + let a = list.value(1); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![None, None, None]); + assert_eq!(a, &expected) +} + +#[test] +fn new_with_field() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![None, None, None]), + Some(vec![Some(4), None, Some(6)]), + ]; + + let mut list = MutableFixedSizeListArray::new_with_field( + MutablePrimitiveArray::::new(), + "custom_items", + false, + 3, + ); + list.try_extend(data).unwrap(); + let list: FixedSizeListArray = list.into(); + + assert_eq!( + list.data_type(), + &ArrowDataType::FixedSizeList( + Box::new(Field::new("custom_items", ArrowDataType::Int32, false)), + 3 + ) + ); + + let a = list.value(0); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![Some(1i32), Some(2), Some(3)]); + assert_eq!(a, &expected); + + let a = list.value(1); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![None, None, None]); + assert_eq!(a, &expected) +} + +#[test] +fn extend_from_self() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + let mut a = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + a.try_extend(data.clone()).unwrap(); + + a.try_extend_from_self(&a.clone()).unwrap(); + let a: FixedSizeListArray = a.into(); + + let mut expected = data.clone(); + expected.extend(data); + + let mut b = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + b.try_extend(expected).unwrap(); + let b: FixedSizeListArray = b.into(); + + assert_eq!(a, b); +} diff --git a/crates/polars/tests/it/arrow/array/growable/binary.rs b/crates/polars/tests/it/arrow/array/growable/binary.rs new file mode 100644 index 000000000000..20c0cd31081b --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/binary.rs @@ -0,0 +1,97 @@ +use arrow::array::growable::{Growable, GrowableBinary}; +use arrow::array::BinaryArray; + +#[test] +fn no_offsets() { + let array = BinaryArray::::from([Some("a"), Some("bc"), None, Some("defh")]); + + let mut a = GrowableBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + + let result: BinaryArray = a.into(); + + let expected = BinaryArray::::from([Some("bc"), None]); + assert_eq!(result, expected); +} + +/// tests extending from a variable-sized (strings and binary) array +/// with an offset and nulls +#[test] +fn with_offsets() { + let array = BinaryArray::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 0, 3); + } + assert_eq!(a.len(), 3); + + let result: BinaryArray = a.into(); + + let expected = BinaryArray::::from([Some("bc"), None, Some("defh")]); + assert_eq!(result, expected); +} + +#[test] +fn test_string_offsets() { + let array = BinaryArray::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 0, 3); + } + assert_eq!(a.len(), 3); + + let result: BinaryArray = a.into(); + + let expected = BinaryArray::::from([Some("bc"), None, Some("defh")]); + assert_eq!(result, expected); +} + +#[test] +fn test_multiple_with_validity() { + let array1 = BinaryArray::::from_slice([b"hello", b"world"]); + let array2 = BinaryArray::::from([Some("1"), None]); + + let mut a = GrowableBinary::new(vec![&array1, &array2], false, 5); + + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 0, 2); + } + assert_eq!(a.len(), 4); + + let result: BinaryArray = a.into(); + + let expected = BinaryArray::::from([Some("hello"), Some("world"), Some("1"), None]); + assert_eq!(result, expected); +} + +#[test] +fn test_string_null_offset_validity() { + let array = BinaryArray::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableBinary::new(vec![&array], true, 0); + + unsafe { + a.extend(0, 1, 2); + } + a.extend_validity(1); + assert_eq!(a.len(), 3); + + let result: BinaryArray = a.into(); + + let expected = BinaryArray::::from([None, Some("defh"), None]); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/boolean.rs b/crates/polars/tests/it/arrow/array/growable/boolean.rs new file mode 100644 index 000000000000..b6721029cb81 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/boolean.rs @@ -0,0 +1,19 @@ +use arrow::array::growable::{Growable, GrowableBoolean}; +use arrow::array::BooleanArray; + +#[test] +fn test_bool() { + let array = BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]); + + let mut a = GrowableBoolean::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + + let result: BooleanArray = a.into(); + + let expected = BooleanArray::from(vec![Some(true), None]); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/dictionary.rs b/crates/polars/tests/it/arrow/array/growable/dictionary.rs new file mode 100644 index 000000000000..e2a48275d7ae --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/dictionary.rs @@ -0,0 +1,72 @@ +use arrow::array::growable::{Growable, GrowableDictionary}; +use arrow::array::*; +use polars_error::PolarsResult; + +#[test] +fn test_single() -> PolarsResult<()> { + let original_data = vec![Some("a"), Some("b"), Some("a")]; + + let data = original_data.clone(); + let mut array = MutableDictionaryArray::>::new(); + array.try_extend(data)?; + let array = array.into(); + + // same values, less keys + let expected = DictionaryArray::try_from_keys( + PrimitiveArray::from_vec(vec![1, 0]), + Box::new(Utf8Array::::from(&original_data)), + ) + .unwrap(); + + let mut growable = GrowableDictionary::new(&[&array], false, 0); + + unsafe { + growable.extend(0, 1, 2); + } + assert_eq!(growable.len(), 2); + + let result: DictionaryArray = growable.into(); + + assert_eq!(result, expected); + Ok(()) +} + +#[test] +fn test_multi() -> PolarsResult<()> { + let mut original_data1 = vec![Some("a"), Some("b"), None, Some("a")]; + let original_data2 = vec![Some("c"), Some("b"), None, Some("a")]; + + let data1 = original_data1.clone(); + let data2 = original_data2.clone(); + + let mut array1 = MutableDictionaryArray::>::new(); + array1.try_extend(data1)?; + let array1: DictionaryArray = array1.into(); + + let mut array2 = MutableDictionaryArray::>::new(); + array2.try_extend(data2)?; + let array2: DictionaryArray = array2.into(); + + // same values, less keys + original_data1.extend(original_data2.iter().cloned()); + let expected = DictionaryArray::try_from_keys( + PrimitiveArray::from(&[Some(1), None, Some(3), None]), + Utf8Array::::from_slice(["a", "b", "c", "b", "a"]).boxed(), + ) + .unwrap(); + + let mut growable = GrowableDictionary::new(&[&array1, &array2], false, 0); + + unsafe { + growable.extend(0, 1, 2); + } + unsafe { + growable.extend(1, 1, 2); + } + assert_eq!(growable.len(), 4); + + let result: DictionaryArray = growable.into(); + + assert_eq!(result, expected); + Ok(()) +} diff --git a/crates/polars/tests/it/arrow/array/growable/fixed_binary.rs b/crates/polars/tests/it/arrow/array/growable/fixed_binary.rs new file mode 100644 index 000000000000..9ebb631f682c --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/fixed_binary.rs @@ -0,0 +1,146 @@ +use arrow::array::growable::{Growable, GrowableFixedSizeBinary}; +use arrow::array::FixedSizeBinaryArray; + +/// tests extending from a variable-sized (strings and binary) array w/ offset with nulls +#[test] +fn basic() { + let array = + FixedSizeBinaryArray::from_iter(vec![Some(b"ab"), Some(b"bc"), None, Some(b"de")], 2); + + let mut a = GrowableFixedSizeBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + + let result: FixedSizeBinaryArray = a.into(); + + let expected = FixedSizeBinaryArray::from_iter(vec![Some("bc"), None], 2); + assert_eq!(result, expected); +} + +/// tests extending from a variable-sized (strings and binary) array +/// with an offset and nulls +#[test] +fn offsets() { + let array = + FixedSizeBinaryArray::from_iter(vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], 2); + let array = array.sliced(1, 3); + + let mut a = GrowableFixedSizeBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 0, 3); + } + assert_eq!(a.len(), 3); + + let result: FixedSizeBinaryArray = a.into(); + + let expected = FixedSizeBinaryArray::from_iter(vec![Some(b"bc"), None, Some(b"fh")], 2); + assert_eq!(result, expected); +} + +#[test] +fn multiple_with_validity() { + let array1 = FixedSizeBinaryArray::from_iter(vec![Some("hello"), Some("world")], 5); + let array2 = FixedSizeBinaryArray::from_iter(vec![Some("12345"), None], 5); + + let mut a = GrowableFixedSizeBinary::new(vec![&array1, &array2], false, 5); + + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 0, 2); + } + assert_eq!(a.len(), 4); + + let result: FixedSizeBinaryArray = a.into(); + + let expected = + FixedSizeBinaryArray::from_iter(vec![Some("hello"), Some("world"), Some("12345"), None], 5); + assert_eq!(result, expected); +} + +#[test] +fn null_offset_validity() { + let array = FixedSizeBinaryArray::from_iter(vec![Some("aa"), Some("bc"), None, Some("fh")], 2); + let array = array.sliced(1, 3); + + let mut a = GrowableFixedSizeBinary::new(vec![&array], true, 0); + + unsafe { + a.extend(0, 1, 2); + } + a.extend_validity(1); + assert_eq!(a.len(), 3); + + let result: FixedSizeBinaryArray = a.into(); + + let expected = FixedSizeBinaryArray::from_iter(vec![None, Some("fh"), None], 2); + assert_eq!(result, expected); +} + +#[test] +fn sized_offsets() { + let array = + FixedSizeBinaryArray::from_iter(vec![Some(&[0, 0]), Some(&[0, 1]), Some(&[0, 2])], 2); + let array = array.sliced(1, 2); + // = [[0, 1], [0, 2]] due to the offset = 1 + + let mut a = GrowableFixedSizeBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 1); + } + unsafe { + a.extend(0, 0, 1); + } + assert_eq!(a.len(), 2); + + let result: FixedSizeBinaryArray = a.into(); + + let expected = FixedSizeBinaryArray::from_iter(vec![Some(&[0, 2]), Some(&[0, 1])], 2); + assert_eq!(result, expected); +} + +/// to, as_box, as_arc +#[test] +fn as_box() { + let array = + FixedSizeBinaryArray::from_iter(vec![Some(b"ab"), Some(b"bc"), None, Some(b"de")], 2); + let mut a = GrowableFixedSizeBinary::new(vec![&array], false, 0); + unsafe { + a.extend(0, 1, 2); + } + + let result = a.as_box(); + let result = result + .as_any() + .downcast_ref::() + .unwrap(); + + let expected = FixedSizeBinaryArray::from_iter(vec![Some("bc"), None], 2); + assert_eq!(&expected, result); +} + +/// as_arc +#[test] +fn as_arc() { + let array = + FixedSizeBinaryArray::from_iter(vec![Some(b"ab"), Some(b"bc"), None, Some(b"de")], 2); + let mut a = GrowableFixedSizeBinary::new(vec![&array], false, 0); + unsafe { + a.extend(0, 1, 2); + } + + let result = a.as_arc(); + let result = result + .as_any() + .downcast_ref::() + .unwrap(); + + let expected = FixedSizeBinaryArray::from_iter(vec![Some("bc"), None], 2); + assert_eq!(&expected, result); +} diff --git a/crates/polars/tests/it/arrow/array/growable/fixed_size_list.rs b/crates/polars/tests/it/arrow/array/growable/fixed_size_list.rs new file mode 100644 index 000000000000..dcdc25d1bda9 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/fixed_size_list.rs @@ -0,0 +1,95 @@ +use arrow::array::growable::{Growable, GrowableFixedSizeList}; +use arrow::array::{ + FixedSizeListArray, MutableFixedSizeListArray, MutablePrimitiveArray, TryExtend, +}; + +fn create_list_array(data: Vec>>>) -> FixedSizeListArray { + let mut array = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + array.try_extend(data).unwrap(); + array.into() +} + +#[test] +fn basic() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![Some(4), Some(5), Some(6)]), + Some(vec![Some(7i32), Some(8), Some(9)]), + ]; + + let array = create_list_array(data); + + let mut a = GrowableFixedSizeList::new(vec![&array], false, 0); + unsafe { + a.extend(0, 0, 1); + } + assert_eq!(a.len(), 1); + + let result: FixedSizeListArray = a.into(); + + let expected = vec![Some(vec![Some(1i32), Some(2), Some(3)])]; + let expected = create_list_array(expected); + + assert_eq!(result, expected) +} + +#[test] +fn null_offset() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(6i32), Some(7), Some(8)]), + ]; + let array = create_list_array(data); + let array = array.sliced(1, 2); + + let mut a = GrowableFixedSizeList::new(vec![&array], false, 0); + unsafe { + a.extend(0, 1, 1); + } + assert_eq!(a.len(), 1); + + let result: FixedSizeListArray = a.into(); + + let expected = vec![Some(vec![Some(6i32), Some(7), Some(8)])]; + let expected = create_list_array(expected); + + assert_eq!(result, expected) +} + +#[test] +fn test_from_two_lists() { + let data_1 = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(6i32), None, Some(8)]), + ]; + let array_1 = create_list_array(data_1); + + let data_2 = vec![ + Some(vec![Some(8i32), Some(7), Some(6)]), + Some(vec![Some(5i32), None, Some(4)]), + Some(vec![Some(2i32), Some(1), Some(0)]), + ]; + let array_2 = create_list_array(data_2); + + let mut a = GrowableFixedSizeList::new(vec![&array_1, &array_2], false, 6); + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 1, 1); + } + assert_eq!(a.len(), 3); + + let result: FixedSizeListArray = a.into(); + + let expected = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(5i32), None, Some(4)]), + ]; + let expected = create_list_array(expected); + + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/list.rs b/crates/polars/tests/it/arrow/array/growable/list.rs new file mode 100644 index 000000000000..1bc0985ceb4f --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/list.rs @@ -0,0 +1,147 @@ +use arrow::array::growable::{Growable, GrowableList}; +use arrow::array::{Array, ListArray, MutableListArray, MutablePrimitiveArray, TryExtend}; +use arrow::datatypes::ArrowDataType; + +fn create_list_array(data: Vec>>>) -> ListArray { + let mut array = MutableListArray::>::new(); + array.try_extend(data).unwrap(); + array.into() +} + +#[test] +fn extension() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(6i32), Some(7), Some(8)]), + ]; + + let array = create_list_array(data); + + let data_type = + ArrowDataType::Extension("ext".to_owned(), Box::new(array.data_type().clone()), None); + let array_ext = ListArray::new( + data_type, + array.offsets().clone(), + array.values().clone(), + array.validity().cloned(), + ); + + let mut a = GrowableList::new(vec![&array_ext], false, 0); + unsafe { + a.extend(0, 0, 1); + } + assert_eq!(a.len(), 1); + + let result: ListArray = a.into(); + assert_eq!(array_ext.data_type(), result.data_type()); +} + +#[test] +fn basic() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(6i32), Some(7), Some(8)]), + ]; + + let array = create_list_array(data); + + let mut a = GrowableList::new(vec![&array], false, 0); + unsafe { + a.extend(0, 0, 1); + } + assert_eq!(a.len(), 1); + + let result: ListArray = a.into(); + + let expected = vec![Some(vec![Some(1i32), Some(2), Some(3)])]; + let expected = create_list_array(expected); + + assert_eq!(result, expected) +} + +#[test] +fn null_offset() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(6i32), Some(7), Some(8)]), + ]; + let array = create_list_array(data); + let array = array.sliced(1, 2); + + let mut a = GrowableList::new(vec![&array], false, 0); + unsafe { + a.extend(0, 1, 1); + } + assert_eq!(a.len(), 1); + + let result: ListArray = a.into(); + + let expected = vec![Some(vec![Some(6i32), Some(7), Some(8)])]; + let expected = create_list_array(expected); + + assert_eq!(result, expected) +} + +#[test] +fn null_offsets() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(6i32), None, Some(8)]), + ]; + let array = create_list_array(data); + let array = array.sliced(1, 2); + + let mut a = GrowableList::new(vec![&array], false, 0); + unsafe { + a.extend(0, 1, 1); + } + assert_eq!(a.len(), 1); + + let result: ListArray = a.into(); + + let expected = vec![Some(vec![Some(6i32), None, Some(8)])]; + let expected = create_list_array(expected); + + assert_eq!(result, expected) +} + +#[test] +fn test_from_two_lists() { + let data_1 = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(6i32), None, Some(8)]), + ]; + let array_1 = create_list_array(data_1); + + let data_2 = vec![ + Some(vec![Some(8i32), Some(7), Some(6)]), + Some(vec![Some(5i32), None, Some(4)]), + Some(vec![Some(2i32), Some(1), Some(0)]), + ]; + let array_2 = create_list_array(data_2); + + let mut a = GrowableList::new(vec![&array_1, &array_2], false, 6); + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 1, 1); + } + assert_eq!(a.len(), 3); + + let result: ListArray = a.into(); + + let expected = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(5i32), None, Some(4)]), + ]; + let expected = create_list_array(expected); + + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/mod.rs b/crates/polars/tests/it/arrow/array/growable/mod.rs new file mode 100644 index 000000000000..43496a1e95b1 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/mod.rs @@ -0,0 +1,75 @@ +mod binary; +mod boolean; +mod dictionary; +mod fixed_binary; +mod fixed_size_list; +mod list; +mod null; +mod primitive; +mod struct_; +mod utf8; + +use arrow::array::growable::make_growable; +use arrow::array::*; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn test_make_growable() { + let array = Int32Array::from_slice([1, 2]); + make_growable(&[&array], false, 2); + + let array = Utf8Array::::from_slice(["a", "aa"]); + make_growable(&[&array], false, 2); + + let array = Utf8Array::::from_slice(["a", "aa"]); + make_growable(&[&array], false, 2); + + let array = BinaryArray::::from_slice([b"a".as_ref(), b"aa".as_ref()]); + make_growable(&[&array], false, 2); + + let array = BinaryArray::::from_slice([b"a".as_ref(), b"aa".as_ref()]); + make_growable(&[&array], false, 2); + + let array = BinaryArray::::from_slice([b"a".as_ref(), b"aa".as_ref()]); + make_growable(&[&array], false, 2); + + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + b"abcd".to_vec().into(), + None, + ); + make_growable(&[&array], false, 2); +} + +#[test] +fn test_make_growable_extension() { + let array = DictionaryArray::try_from_keys( + Int32Array::from_slice([1, 0]), + Int32Array::from_slice([1, 2]).boxed(), + ) + .unwrap(); + make_growable(&[&array], false, 2); + + let data_type = + ArrowDataType::Extension("ext".to_owned(), Box::new(ArrowDataType::Int32), None); + let array = Int32Array::from_slice([1, 2]).to(data_type.clone()); + let array_grown = make_growable(&[&array], false, 2).as_box(); + assert_eq!(array_grown.data_type(), &data_type); + + let data_type = ArrowDataType::Extension( + "ext".to_owned(), + Box::new(ArrowDataType::Struct(vec![Field::new( + "a", + ArrowDataType::Int32, + false, + )])), + None, + ); + let array = StructArray::new( + data_type.clone(), + vec![Int32Array::from_slice([1, 2]).boxed()], + None, + ); + let array_grown = make_growable(&[&array], false, 2).as_box(); + assert_eq!(array_grown.data_type(), &data_type); +} diff --git a/crates/polars/tests/it/arrow/array/growable/null.rs b/crates/polars/tests/it/arrow/array/growable/null.rs new file mode 100644 index 000000000000..2d6a118a117c --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/null.rs @@ -0,0 +1,21 @@ +use arrow::array::growable::{Growable, GrowableNull}; +use arrow::array::NullArray; +use arrow::datatypes::ArrowDataType; + +#[test] +fn null() { + let mut mutable = GrowableNull::default(); + + unsafe { + mutable.extend(0, 1, 2); + } + unsafe { + mutable.extend(1, 0, 1); + } + assert_eq!(mutable.len(), 3); + + let result: NullArray = mutable.into(); + + let expected = NullArray::new(ArrowDataType::Null, 3); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/primitive.rs b/crates/polars/tests/it/arrow/array/growable/primitive.rs new file mode 100644 index 000000000000..37c105f2c728 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/primitive.rs @@ -0,0 +1,82 @@ +use arrow::array::growable::{Growable, GrowablePrimitive}; +use arrow::array::PrimitiveArray; + +/// tests extending from a primitive array w/ offset nor nulls +#[test] +fn basics() { + let b = PrimitiveArray::::from(vec![Some(1), Some(2), Some(3)]); + let mut a = GrowablePrimitive::new(vec![&b], false, 3); + unsafe { + a.extend(0, 0, 2); + } + assert_eq!(a.len(), 2); + let result: PrimitiveArray = a.into(); + let expected = PrimitiveArray::::from(vec![Some(1), Some(2)]); + assert_eq!(result, expected); +} + +/// tests extending from a primitive array with offset w/ nulls +#[test] +fn offset() { + let b = PrimitiveArray::::from(vec![Some(1), Some(2), Some(3)]); + let b = b.sliced(1, 2); + let mut a = GrowablePrimitive::new(vec![&b], false, 2); + unsafe { + a.extend(0, 0, 2); + } + assert_eq!(a.len(), 2); + let result: PrimitiveArray = a.into(); + let expected = PrimitiveArray::::from(vec![Some(2), Some(3)]); + assert_eq!(result, expected); +} + +/// tests extending from a primitive array with offset and nulls +#[test] +fn null_offset() { + let b = PrimitiveArray::::from(vec![Some(1), None, Some(3)]); + let b = b.sliced(1, 2); + let mut a = GrowablePrimitive::new(vec![&b], false, 2); + unsafe { + a.extend(0, 0, 2); + } + assert_eq!(a.len(), 2); + let result: PrimitiveArray = a.into(); + let expected = PrimitiveArray::::from(vec![None, Some(3)]); + assert_eq!(result, expected); +} + +#[test] +fn null_offset_validity() { + let b = PrimitiveArray::::from(&[Some(1), Some(2), Some(3)]); + let b = b.sliced(1, 2); + let mut a = GrowablePrimitive::new(vec![&b], true, 2); + unsafe { + a.extend(0, 0, 2); + } + a.extend_validity(3); + unsafe { + a.extend(0, 1, 1); + } + assert_eq!(a.len(), 6); + let result: PrimitiveArray = a.into(); + let expected = PrimitiveArray::::from(&[Some(2), Some(3), None, None, None, Some(3)]); + assert_eq!(result, expected); +} + +#[test] +fn joining_arrays() { + let b = PrimitiveArray::::from(&[Some(1), Some(2), Some(3)]); + let c = PrimitiveArray::::from(&[Some(4), Some(5), Some(6)]); + let mut a = GrowablePrimitive::new(vec![&b, &c], false, 4); + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 1, 2); + } + assert_eq!(a.len(), 4); + let result: PrimitiveArray = a.into(); + + let expected = PrimitiveArray::::from(&[Some(1), Some(2), Some(5), Some(6)]); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/struct_.rs b/crates/polars/tests/it/arrow/array/growable/struct_.rs new file mode 100644 index 000000000000..809e70749f09 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/struct_.rs @@ -0,0 +1,139 @@ +use arrow::array::growable::{Growable, GrowableStruct}; +use arrow::array::{Array, PrimitiveArray, StructArray, Utf8Array}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ArrowDataType, Field}; + +fn some_values() -> (ArrowDataType, Vec>) { + let strings: Box = Box::new(Utf8Array::::from([ + Some("a"), + Some("aa"), + None, + Some("mark"), + Some("doe"), + ])); + let ints: Box = Box::new(PrimitiveArray::::from(&[ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + ])); + let fields = vec![ + Field::new("f1", ArrowDataType::Utf8, true), + Field::new("f2", ArrowDataType::Int32, true), + ]; + (ArrowDataType::Struct(fields), vec![strings, ints]) +} + +#[test] +fn basic() { + let (fields, values) = some_values(); + + let array = StructArray::new(fields.clone(), values.clone(), None); + + let mut a = GrowableStruct::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + let result: StructArray = a.into(); + + let expected = StructArray::new( + fields, + vec![values[0].sliced(1, 2), values[1].sliced(1, 2)], + None, + ); + assert_eq!(result, expected) +} + +#[test] +fn offset() { + let (fields, values) = some_values(); + + let array = StructArray::new(fields.clone(), values.clone(), None).sliced(1, 3); + + let mut a = GrowableStruct::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + let result: StructArray = a.into(); + + let expected = StructArray::new( + fields, + vec![values[0].sliced(2, 2), values[1].sliced(2, 2)], + None, + ); + + assert_eq!(result, expected); +} + +#[test] +fn nulls() { + let (fields, values) = some_values(); + + let array = StructArray::new( + fields.clone(), + values.clone(), + Some(Bitmap::from_u8_slice([0b00000010], 5)), + ); + + let mut a = GrowableStruct::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + let result: StructArray = a.into(); + + let expected = StructArray::new( + fields, + vec![values[0].sliced(1, 2), values[1].sliced(1, 2)], + Some(Bitmap::from_u8_slice([0b00000010], 5).sliced(1, 2)), + ); + + assert_eq!(result, expected) +} + +#[test] +fn many() { + let (fields, values) = some_values(); + + let array = StructArray::new(fields.clone(), values.clone(), None); + + let mut mutable = GrowableStruct::new(vec![&array, &array], true, 0); + + unsafe { + mutable.extend(0, 1, 2); + } + unsafe { + mutable.extend(1, 0, 2); + } + mutable.extend_validity(1); + assert_eq!(mutable.len(), 5); + let result = mutable.as_box(); + + let expected_string: Box = Box::new(Utf8Array::::from([ + Some("aa"), + None, + Some("a"), + Some("aa"), + None, + ])); + let expected_int: Box = Box::new(PrimitiveArray::::from(vec![ + Some(2), + Some(3), + Some(1), + Some(2), + None, + ])); + + let expected = StructArray::new( + fields, + vec![expected_string, expected_int], + Some(Bitmap::from([true, true, true, true, false])), + ); + assert_eq!(expected, result.as_ref()) +} diff --git a/crates/polars/tests/it/arrow/array/growable/utf8.rs b/crates/polars/tests/it/arrow/array/growable/utf8.rs new file mode 100644 index 000000000000..af2be2ab9867 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/utf8.rs @@ -0,0 +1,97 @@ +use arrow::array::growable::{Growable, GrowableUtf8}; +use arrow::array::Utf8Array; + +/// tests extending from a variable-sized (strings and binary) array w/ offset with nulls +#[test] +fn validity() { + let array = Utf8Array::::from([Some("a"), Some("bc"), None, Some("defh")]); + + let mut a = GrowableUtf8::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + + let result: Utf8Array = a.into(); + + let expected = Utf8Array::::from([Some("bc"), None]); + assert_eq!(result, expected); +} + +/// tests extending from a variable-sized (strings and binary) array +/// with an offset and nulls +#[test] +fn offsets() { + let array = Utf8Array::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableUtf8::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 0, 3); + } + assert_eq!(a.len(), 3); + + let result: Utf8Array = a.into(); + + let expected = Utf8Array::::from([Some("bc"), None, Some("defh")]); + assert_eq!(result, expected); +} + +#[test] +fn offsets2() { + let array = Utf8Array::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableUtf8::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 0, 3); + } + assert_eq!(a.len(), 3); + + let result: Utf8Array = a.into(); + + let expected = Utf8Array::::from([Some("bc"), None, Some("defh")]); + assert_eq!(result, expected); +} + +#[test] +fn multiple_with_validity() { + let array1 = Utf8Array::::from_slice(["hello", "world"]); + let array2 = Utf8Array::::from([Some("1"), None]); + + let mut a = GrowableUtf8::new(vec![&array1, &array2], false, 5); + + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 0, 2); + } + assert_eq!(a.len(), 4); + + let result: Utf8Array = a.into(); + + let expected = Utf8Array::::from([Some("hello"), Some("world"), Some("1"), None]); + assert_eq!(result, expected); +} + +#[test] +fn null_offset_validity() { + let array = Utf8Array::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableUtf8::new(vec![&array], true, 0); + + unsafe { + a.extend(0, 1, 2); + } + a.extend_validity(1); + assert_eq!(a.len(), 3); + + let result: Utf8Array = a.into(); + + let expected = Utf8Array::::from([None, Some("defh"), None]); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/list/mod.rs b/crates/polars/tests/it/arrow/array/list/mod.rs new file mode 100644 index 000000000000..77e443781c17 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/list/mod.rs @@ -0,0 +1,70 @@ +use arrow::array::*; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +mod mutable; + +#[test] +fn debug() { + let values = Buffer::from(vec![1, 2, 3, 4, 5]); + let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); + + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let array = ListArray::::new( + data_type, + vec![0, 2, 2, 3, 5].try_into().unwrap(), + Box::new(values), + None, + ); + + assert_eq!(format!("{array:?}"), "ListArray[[1, 2], [], [3], [4, 5]]"); +} + +#[test] +#[should_panic] +fn test_nested_panic() { + let values = Buffer::from(vec![1, 2, 3, 4, 5]); + let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); + + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let array = ListArray::::new( + data_type.clone(), + vec![0, 2, 2, 3, 5].try_into().unwrap(), + Box::new(values), + None, + ); + + // The datatype for the nested array has to be created considering + // the nested structure of the child data + let _ = ListArray::::new( + data_type, + vec![0, 2, 4].try_into().unwrap(), + Box::new(array), + None, + ); +} + +#[test] +fn test_nested_display() { + let values = Buffer::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); + + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let array = ListArray::::new( + data_type, + vec![0, 2, 4, 7, 7, 8, 10].try_into().unwrap(), + Box::new(values), + None, + ); + + let data_type = ListArray::::default_datatype(array.data_type().clone()); + let nested = ListArray::::new( + data_type, + vec![0, 2, 5, 6].try_into().unwrap(), + Box::new(array), + None, + ); + + let expected = "ListArray[[[1, 2], [3, 4]], [[5, 6, 7], [], [8]], [[9, 10]]]"; + assert_eq!(format!("{nested:?}"), expected); +} diff --git a/crates/polars/tests/it/arrow/array/list/mutable.rs b/crates/polars/tests/it/arrow/array/list/mutable.rs new file mode 100644 index 000000000000..2d4ba0c4d2f1 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/list/mutable.rs @@ -0,0 +1,76 @@ +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +#[test] +fn basics() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + + let mut array = MutableListArray::>::new(); + array.try_extend(data).unwrap(); + let array: ListArray = array.into(); + + let values = PrimitiveArray::::new( + ArrowDataType::Int32, + Buffer::from(vec![1, 2, 3, 4, 0, 6]), + Some(Bitmap::from([true, true, true, true, false, true])), + ); + + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let expected = ListArray::::new( + data_type, + vec![0, 3, 3, 6].try_into().unwrap(), + Box::new(values), + Some(Bitmap::from([true, false, true])), + ); + assert_eq!(expected, array); +} + +#[test] +fn with_capacity() { + let array = MutableListArray::>::with_capacity(10); + assert!(array.offsets().capacity() >= 10); + assert_eq!(array.offsets().len_proxy(), 0); + assert_eq!(array.values().values().capacity(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn push() { + let mut array = MutableListArray::>::new(); + array + .try_push(Some(vec![Some(1i32), Some(2), Some(3)])) + .unwrap(); + assert_eq!(array.len(), 1); + assert_eq!(array.values().values().as_ref(), [1, 2, 3]); + assert_eq!(array.offsets().as_slice(), [0, 3]); + assert_eq!(array.validity(), None); +} + +#[test] +fn extend_from_self() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + let mut a = MutableListArray::>::new(); + a.try_extend(data.clone()).unwrap(); + + a.try_extend_from_self(&a.clone()).unwrap(); + let a: ListArray = a.into(); + + let mut expected = data.clone(); + expected.extend(data); + + let mut b = MutableListArray::>::new(); + b.try_extend(expected).unwrap(); + let b: ListArray = b.into(); + + assert_eq!(a, b); +} diff --git a/crates/polars/tests/it/arrow/array/map/mod.rs b/crates/polars/tests/it/arrow/array/map/mod.rs new file mode 100644 index 000000000000..30d1d263a9d7 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/map/mod.rs @@ -0,0 +1,52 @@ +use arrow::array::*; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn basics() { + let dt = ArrowDataType::Struct(vec![ + Field::new("a", ArrowDataType::Utf8, true), + Field::new("b", ArrowDataType::Utf8, true), + ]); + let data_type = ArrowDataType::Map(Box::new(Field::new("a", dt.clone(), true)), false); + + let field = StructArray::new( + dt.clone(), + vec![ + Box::new(Utf8Array::::from_slice(["a", "aa", "aaa"])) as _, + Box::new(Utf8Array::::from_slice(["b", "bb", "bbb"])), + ], + None, + ); + + let array = MapArray::new( + data_type, + vec![0, 1, 2].try_into().unwrap(), + Box::new(field), + None, + ); + + assert_eq!( + array.value(0), + Box::new(StructArray::new( + dt.clone(), + vec![ + Box::new(Utf8Array::::from_slice(["a"])) as _, + Box::new(Utf8Array::::from_slice(["b"])), + ], + None, + )) as Box + ); + + let sliced = array.sliced(1, 1); + assert_eq!( + sliced.value(0), + Box::new(StructArray::new( + dt, + vec![ + Box::new(Utf8Array::::from_slice(["aa"])) as _, + Box::new(Utf8Array::::from_slice(["bb"])), + ], + None, + )) as Box + ); +} diff --git a/crates/polars/tests/it/arrow/array/mod.rs b/crates/polars/tests/it/arrow/array/mod.rs new file mode 100644 index 000000000000..89fbe3f19ad5 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/mod.rs @@ -0,0 +1,141 @@ +mod binary; +mod boolean; +mod dictionary; +mod equal; +mod fixed_size_binary; +mod fixed_size_list; +mod growable; +mod list; +mod map; +mod primitive; +mod struct_; +mod union; +mod utf8; + +use arrow::array::{clone, new_empty_array, new_null_array, Array, PrimitiveArray}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ArrowDataType, Field, UnionMode}; + +#[test] +fn nulls() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ]; + let a = datatypes + .into_iter() + .all(|x| new_null_array(x, 10).null_count() == 10); + assert!(a); + + // unions' null count is always 0 + let datatypes = vec![ + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Dense, + ), + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Sparse, + ), + ]; + let a = datatypes + .into_iter() + .all(|x| new_null_array(x, 10).null_count() == 0); + assert!(a); +} + +#[test] +fn empty() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ArrowDataType::List(Box::new(Field::new( + "a", + ArrowDataType::Extension("ext".to_owned(), Box::new(ArrowDataType::Int32), None), + true, + ))), + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Sparse, + ), + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Dense, + ), + ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Int32, true)]), + ]; + let a = datatypes.into_iter().all(|x| new_empty_array(x).len() == 0); + assert!(a); +} + +#[test] +fn empty_extension() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Sparse, + ), + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Dense, + ), + ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Int32, true)]), + ]; + let a = datatypes + .into_iter() + .map(|dt| ArrowDataType::Extension("ext".to_owned(), Box::new(dt), None)) + .all(|x| { + let a = new_empty_array(x); + a.len() == 0 && matches!(a.data_type(), ArrowDataType::Extension(_, _, _)) + }); + assert!(a); +} + +#[test] +fn test_clone() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ]; + let a = datatypes + .into_iter() + .all(|x| clone(new_null_array(x.clone(), 10).as_ref()) == new_null_array(x, 10)); + assert!(a); +} + +#[test] +fn test_with_validity() { + let arr = PrimitiveArray::from_slice([1i32, 2, 3]); + let validity = Bitmap::from(&[true, false, true]); + let arr = arr.with_validity(Some(validity)); + let arr_ref = arr.as_any().downcast_ref::>().unwrap(); + + let expected = PrimitiveArray::from(&[Some(1i32), None, Some(3)]); + assert_eq!(arr_ref, &expected); +} + +// check that we ca derive stuff +#[derive(PartialEq, Clone, Debug)] +struct A { + array: Box, +} diff --git a/crates/polars/tests/it/arrow/array/primitive/fmt.rs b/crates/polars/tests/it/arrow/array/primitive/fmt.rs new file mode 100644 index 000000000000..6ab0ffa1ee8b --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/fmt.rs @@ -0,0 +1,224 @@ +use arrow::array::*; +use arrow::datatypes::*; +use arrow::types::{days_ms, months_days_ns}; + +#[test] +fn debug_int32() { + let array = Int32Array::from(&[Some(1), None, Some(2)]); + assert_eq!(format!("{array:?}"), "Int32[1, None, 2]"); +} + +#[test] +fn debug_date32() { + let array = Int32Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Date32); + assert_eq!(format!("{array:?}"), "Date32[1970-01-02, None, 1970-01-03]"); +} + +#[test] +fn debug_time32s() { + let array = + Int32Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Time32(TimeUnit::Second)); + assert_eq!( + format!("{array:?}"), + "Time32(Second)[00:00:01, None, 00:00:02]" + ); +} + +#[test] +fn debug_time32ms() { + let array = Int32Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Time32(TimeUnit::Millisecond)); + assert_eq!( + format!("{array:?}"), + "Time32(Millisecond)[00:00:00.001, None, 00:00:00.002]" + ); +} + +#[test] +fn debug_interval_d() { + let array = Int32Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Interval(IntervalUnit::YearMonth)); + assert_eq!(format!("{array:?}"), "Interval(YearMonth)[1m, None, 2m]"); +} + +#[test] +fn debug_int64() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Int64); + assert_eq!(format!("{array:?}"), "Int64[1, None, 2]"); +} + +#[test] +fn debug_date64() { + let array = Int64Array::from(&[Some(1), None, Some(86400000)]).to(ArrowDataType::Date64); + assert_eq!(format!("{array:?}"), "Date64[1970-01-01, None, 1970-01-02]"); +} + +#[test] +fn debug_time64us() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Time64(TimeUnit::Microsecond)); + assert_eq!( + format!("{array:?}"), + "Time64(Microsecond)[00:00:00.000001, None, 00:00:00.000002]" + ); +} + +#[test] +fn debug_time64ns() { + let array = + Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Time64(TimeUnit::Nanosecond)); + assert_eq!( + format!("{array:?}"), + "Time64(Nanosecond)[00:00:00.000000001, None, 00:00:00.000000002]" + ); +} + +#[test] +fn debug_timestamp_s() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Second, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Second, None)[1970-01-01 00:00:01, None, 1970-01-01 00:00:02]" + ); +} + +#[test] +fn debug_timestamp_ms() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Millisecond, None)[1970-01-01 00:00:00.001, None, 1970-01-01 00:00:00.002]" + ); +} + +#[test] +fn debug_timestamp_us() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Microsecond, None)[1970-01-01 00:00:00.000001, None, 1970-01-01 00:00:00.000002]" + ); +} + +#[test] +fn debug_timestamp_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, None)[1970-01-01 00:00:00.000000001, None, 1970-01-01 00:00:00.000000002]" + ); +} + +#[test] +fn debug_timestamp_tz_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( + TimeUnit::Nanosecond, + Some("+02:00".to_string()), + )); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, Some(\"+02:00\"))[1970-01-01 02:00:00.000000001 +02:00, None, 1970-01-01 02:00:00.000000002 +02:00]" + ); +} + +#[test] +fn debug_timestamp_tz_not_parsable() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( + TimeUnit::Nanosecond, + Some("aa".to_string()), + )); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, Some(\"aa\"))[1 (aa), None, 2 (aa)]" + ); +} + +#[cfg(feature = "chrono-tz")] +#[test] +fn debug_timestamp_tz1_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( + TimeUnit::Nanosecond, + Some("Europe/Lisbon".to_string()), + )); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, Some(\"Europe/Lisbon\"))[1970-01-01 01:00:00.000000001 CET, None, 1970-01-01 01:00:00.000000002 CET]" + ); +} + +#[test] +fn debug_duration_ms() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Millisecond)); + assert_eq!( + format!("{array:?}"), + "Duration(Millisecond)[1ms, None, 2ms]" + ); +} + +#[test] +fn debug_duration_s() { + let array = + Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Duration(TimeUnit::Second)); + assert_eq!(format!("{array:?}"), "Duration(Second)[1s, None, 2s]"); +} + +#[test] +fn debug_duration_us() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Microsecond)); + assert_eq!( + format!("{array:?}"), + "Duration(Microsecond)[1us, None, 2us]" + ); +} + +#[test] +fn debug_duration_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Nanosecond)); + assert_eq!(format!("{array:?}"), "Duration(Nanosecond)[1ns, None, 2ns]"); +} + +#[test] +fn debug_decimal() { + let array = + Int128Array::from(&[Some(12345), None, Some(23456)]).to(ArrowDataType::Decimal(5, 2)); + assert_eq!(format!("{array:?}"), "Decimal(5, 2)[123.45, None, 234.56]"); +} + +#[test] +fn debug_decimal1() { + let array = + Int128Array::from(&[Some(12345), None, Some(23456)]).to(ArrowDataType::Decimal(5, 1)); + assert_eq!(format!("{array:?}"), "Decimal(5, 1)[1234.5, None, 2345.6]"); +} + +#[test] +fn debug_interval_days_ms() { + let array = DaysMsArray::from(&[Some(days_ms::new(1, 1)), None, Some(days_ms::new(2, 2))]); + assert_eq!( + format!("{array:?}"), + "Interval(DayTime)[1d1ms, None, 2d2ms]" + ); +} + +#[test] +fn debug_months_days_ns() { + let data = &[ + Some(months_days_ns::new(1, 1, 2)), + None, + Some(months_days_ns::new(2, 3, 3)), + ]; + + let array = MonthsDaysNsArray::from(&data); + + assert_eq!( + format!("{array:?}"), + "Interval(MonthDayNano)[1m1d2ns, None, 2m3d3ns]" + ); +} diff --git a/crates/polars/tests/it/arrow/array/primitive/mod.rs b/crates/polars/tests/it/arrow/array/primitive/mod.rs new file mode 100644 index 000000000000..e36b68f5a6a7 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/mod.rs @@ -0,0 +1,140 @@ +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::*; +use arrow::types::months_days_ns; + +mod fmt; +mod mutable; +mod to_mutable; + +#[test] +fn basics() { + let data = vec![Some(1), None, Some(10)]; + + let array = Int32Array::from_iter(data); + + assert_eq!(array.value(0), 1); + assert_eq!(array.value(1), 0); + assert_eq!(array.value(2), 10); + assert_eq!(array.values().as_slice(), &[1, 0, 10]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = Int32Array::new( + ArrowDataType::Int32, + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert_eq!(array.value(0), 0); + assert_eq!(array.value(1), 10); + assert_eq!(array.values().as_slice(), &[0, 10]); + + unsafe { + assert_eq!(array.value_unchecked(0), 0); + assert_eq!(array.value_unchecked(1), 10); + } +} + +#[test] +fn empty() { + let array = Int32Array::new_empty(ArrowDataType::Int32); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn from() { + let data = vec![Some(1), None, Some(10)]; + + let array = PrimitiveArray::from(data.clone()); + assert_eq!(array.len(), 3); + + let array = PrimitiveArray::from_iter(data.clone()); + assert_eq!(array.len(), 3); + + let array = PrimitiveArray::from_trusted_len_iter(data.into_iter()); + assert_eq!(array.len(), 3); + + let data = vec![1i32, 2, 3]; + + let array = PrimitiveArray::from_values(data.clone()); + assert_eq!(array.len(), 3); + + let array = PrimitiveArray::from_trusted_len_values_iter(data.into_iter()); + assert_eq!(array.len(), 3); +} + +#[test] +fn months_days_ns_from_slice() { + let data = &[ + months_days_ns::new(1, 1, 2), + months_days_ns::new(1, 1, 3), + months_days_ns::new(2, 3, 3), + ]; + + let array = MonthsDaysNsArray::from_slice(data); + + let a = array.values().as_slice(); + assert_eq!(a, data.as_ref()); +} + +#[test] +fn wrong_data_type() { + let values = Buffer::from(b"abbb".to_vec()); + assert!(PrimitiveArray::try_new(ArrowDataType::Utf8, values, None).is_err()); +} + +#[test] +fn wrong_len() { + let values = Buffer::from(b"abbb".to_vec()); + let validity = Some([true, false].into()); + assert!(PrimitiveArray::try_new(ArrowDataType::Utf8, values, validity).is_err()); +} + +#[test] +fn into_mut_1() { + let values = Buffer::::from(vec![0, 1]); + let a = values.clone(); // cloned values + assert_eq!(a, values); + let array = PrimitiveArray::new(ArrowDataType::Int32, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_2() { + let values = Buffer::::from(vec![0, 1]); + let validity = Some([true, false].into()); + let a = validity.clone(); // cloned values + assert_eq!(a, validity); + let array = PrimitiveArray::new(ArrowDataType::Int32, values, validity); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_3() { + let values = Buffer::::from(vec![0, 1]); + let validity = Some([true, false].into()); + let array = PrimitiveArray::new(ArrowDataType::Int32, values, validity); + assert!(array.into_mut().is_right()); +} + +#[test] +fn into_iter() { + let data = vec![Some(1), None, Some(10)]; + let rev = data.clone().into_iter().rev(); + + let array: Int32Array = data.clone().into_iter().collect(); + + assert_eq!(array.clone().into_iter().collect::>(), data); + + assert!(array.into_iter().rev().eq(rev)) +} diff --git a/crates/polars/tests/it/arrow/array/primitive/mutable.rs b/crates/polars/tests/it/arrow/array/primitive/mutable.rs new file mode 100644 index 000000000000..bd4d3831dc82 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/mutable.rs @@ -0,0 +1,328 @@ +use arrow::array::*; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::datatypes::ArrowDataType; +use polars_error::PolarsResult; + +#[test] +fn from_and_into_data() { + let a = MutablePrimitiveArray::try_new( + ArrowDataType::Int32, + vec![1i32, 0], + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + assert_eq!(a.len(), 2); + let (a, b, c) = a.into_inner(); + assert_eq!(a, ArrowDataType::Int32); + assert_eq!(b, Vec::from([1i32, 0])); + assert_eq!(c, Some(MutableBitmap::from([true, false]))); +} + +#[test] +fn from_vec() { + let a = MutablePrimitiveArray::from_vec(Vec::from([1i32, 0])); + assert_eq!(a.len(), 2); +} + +#[test] +fn to() { + let a = MutablePrimitiveArray::try_new( + ArrowDataType::Int32, + vec![1i32, 0], + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + let a = a.to(ArrowDataType::Date32); + assert_eq!(a.data_type(), &ArrowDataType::Date32); +} + +#[test] +fn values_mut_slice() { + let mut a = MutablePrimitiveArray::try_new( + ArrowDataType::Int32, + vec![1i32, 0], + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + let values = a.values_mut_slice(); + + values[0] = 10; + assert_eq!(a.values()[0], 10); +} + +#[test] +fn push() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.push(None); + a.push_null(); + assert_eq!(a.len(), 3); + assert!(a.is_valid(0)); + assert!(!a.is_valid(1)); + assert!(!a.is_valid(2)); + + assert_eq!(a.values(), &Vec::from([1, 0, 0])); +} + +#[test] +fn pop() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.push(None); + a.push(Some(2)); + a.push_null(); + assert_eq!(a.pop(), None); + assert_eq!(a.pop(), Some(2)); + assert_eq!(a.pop(), None); + assert!(a.is_valid(0)); + assert_eq!(a.values(), &Vec::from([1])); + assert_eq!(a.pop(), Some(1)); + assert_eq!(a.len(), 0); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 0); +} + +#[test] +fn pop_all_some() { + let mut a = MutablePrimitiveArray::::new(); + for v in 0..8 { + a.push(Some(v)); + } + + a.push(Some(8)); + assert_eq!(a.pop(), Some(8)); + assert_eq!(a.pop(), Some(7)); + assert_eq!(a.pop(), Some(6)); + assert_eq!(a.pop(), Some(5)); + assert_eq!(a.pop(), Some(4)); + assert_eq!(a.len(), 4); + assert!(a.is_valid(0)); + assert!(a.is_valid(1)); + assert!(a.is_valid(2)); + assert!(a.is_valid(3)); + assert_eq!(a.values(), &Vec::from([0, 1, 2, 3])); +} + +#[test] +fn set() { + let mut a = MutablePrimitiveArray::::from([Some(1), None]); + + a.set(0, Some(2)); + a.set(1, Some(1)); + + assert_eq!(a.len(), 2); + assert!(a.is_valid(0)); + assert!(a.is_valid(1)); + + assert_eq!(a.values(), &Vec::from([2, 1])); + + let mut a = MutablePrimitiveArray::::from_slice([1, 2]); + + a.set(0, Some(2)); + a.set(1, None); + + assert_eq!(a.len(), 2); + assert!(a.is_valid(0)); + assert!(!a.is_valid(1)); + + assert_eq!(a.values(), &Vec::from([2, 0])); +} + +#[test] +fn from_iter() { + let a = MutablePrimitiveArray::::from_iter((0..2).map(Some)); + assert_eq!(a.len(), 2); + let validity = a.validity().unwrap(); + assert_eq!(validity.unset_bits(), 0); +} + +#[test] +fn natural_arc() { + let a = MutablePrimitiveArray::::from_slice([0, 1]).into_arc(); + assert_eq!(a.len(), 2); +} + +#[test] +fn as_arc() { + let a = MutablePrimitiveArray::::from_slice([0, 1]).as_arc(); + assert_eq!(a.len(), 2); +} + +#[test] +fn as_box() { + let a = MutablePrimitiveArray::::from_slice([0, 1]).as_box(); + assert_eq!(a.len(), 2); +} + +#[test] +fn shrink_to_fit_and_capacity() { + let mut a = MutablePrimitiveArray::::with_capacity(100); + a.push(Some(1)); + a.try_push(None).unwrap(); + assert!(a.capacity() >= 100); + (&mut a as &mut dyn MutableArray).shrink_to_fit(); + assert_eq!(a.capacity(), 2); +} + +#[test] +fn only_nulls() { + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.push(None); + let a: PrimitiveArray = a.into(); + assert_eq!(a.validity(), Some(&Bitmap::from([false, false]))); +} + +#[test] +fn from_trusted_len() { + let a = + MutablePrimitiveArray::::from_trusted_len_iter(vec![Some(1), None, None].into_iter()); + let a: PrimitiveArray = a.into(); + assert_eq!(a.validity(), Some(&Bitmap::from([true, false, false]))); + + let a = unsafe { + MutablePrimitiveArray::::from_trusted_len_iter_unchecked( + vec![Some(1), None].into_iter(), + ) + }; + let a: PrimitiveArray = a.into(); + assert_eq!(a.validity(), Some(&Bitmap::from([true, false]))); +} + +#[test] +fn extend_trusted_len() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_trusted_len(vec![Some(1), Some(2)].into_iter()); + let validity = a.validity().unwrap(); + assert_eq!(validity.unset_bits(), 0); + a.extend_trusted_len(vec![None, Some(4)].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([true, true, false, true])) + ); + assert_eq!(a.values(), &Vec::::from([1, 2, 0, 4])); +} + +#[test] +fn extend_constant_no_validity() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.extend_constant(2, Some(3)); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &Vec::::from([1, 3, 3])); +} + +#[test] +fn extend_constant_validity() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.extend_constant(2, None); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([true, false, false])) + ); + assert_eq!(a.values(), &Vec::::from([1, 0, 0])); +} + +#[test] +fn extend_constant_validity_inverse() { + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_constant(2, Some(1)); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); + assert_eq!(a.values(), &Vec::::from([0, 1, 1])); +} + +#[test] +fn extend_constant_validity_none() { + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_constant(2, None); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, false, false])) + ); + assert_eq!(a.values(), &Vec::::from([0, 0, 0])); +} + +#[test] +fn extend_trusted_len_values() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_trusted_len_values(vec![1, 2, 3].into_iter()); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &Vec::::from([1, 2, 3])); + + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_trusted_len_values(vec![1, 2].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); +} + +#[test] +fn extend_from_slice() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_from_slice(&[1, 2, 3]); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &Vec::::from([1, 2, 3])); + + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_from_slice(&[1, 2]); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); +} + +#[test] +fn set_validity() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_trusted_len(vec![Some(1), Some(2)].into_iter()); + let validity = a.validity().unwrap(); + assert_eq!(validity.unset_bits(), 0); + + // test that upon conversion to array the bitmap is set to None + let arr: PrimitiveArray<_> = a.clone().into(); + assert_eq!(arr.validity(), None); + + // test set_validity + a.set_validity(Some(MutableBitmap::from([false, true]))); + assert_eq!(a.validity(), Some(&MutableBitmap::from([false, true]))); +} + +#[test] +fn set_values() { + let mut a = MutablePrimitiveArray::::from_slice([1, 2]); + a.set_values(Vec::from([1, 3])); + assert_eq!(a.values().as_slice(), [1, 3]); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = std::iter::repeat(Some(1)).take(2).map(PolarsResult::Ok); + let a = MutablePrimitiveArray::try_from_trusted_len_iter(iter).unwrap(); + assert_eq!(a, MutablePrimitiveArray::from([Some(1), Some(1)])); +} + +#[test] +fn wrong_data_type() { + assert!(MutablePrimitiveArray::::try_new(ArrowDataType::Utf8, vec![], None).is_err()); +} + +#[test] +fn extend_from_self() { + let mut a = MutablePrimitiveArray::from([Some(1), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutablePrimitiveArray::from([Some(1), None, Some(1), None]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/primitive/to_mutable.rs b/crates/polars/tests/it/arrow/array/primitive/to_mutable.rs new file mode 100644 index 000000000000..0cc32155a318 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/to_mutable.rs @@ -0,0 +1,53 @@ +use arrow::array::PrimitiveArray; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; +use either::Either; + +#[test] +fn array_to_mutable() { + let data = vec![1, 2, 3]; + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.into(), None); + + // to mutable push and freeze again + let mut mut_arr = arr.into_mut().unwrap_right(); + mut_arr.push(Some(5)); + let immut: PrimitiveArray = mut_arr.into(); + assert_eq!(immut.values().as_slice(), [1, 2, 3, 5]); + + // let's cause a realloc and see if miri is ok + let mut mut_arr = immut.into_mut().unwrap_right(); + mut_arr.extend_constant(256, Some(9)); + let immut: PrimitiveArray = mut_arr.into(); + assert_eq!(immut.values().len(), 256 + 4); +} + +#[test] +fn array_to_mutable_not_owned() { + let data = vec![1, 2, 3]; + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.into(), None); + let arr2 = arr.clone(); + + // to the `to_mutable` should fail and we should get back the original array + match arr2.into_mut() { + Either::Left(arr2) => { + assert_eq!(arr, arr2); + }, + _ => panic!(), + } +} + +#[test] +#[allow(clippy::redundant_clone)] +fn array_to_mutable_validity() { + let data = vec![1, 2, 3]; + + // both have a single reference should be ok + let bitmap = Bitmap::from_iter([true, false, true]); + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.clone().into(), Some(bitmap)); + assert!(matches!(arr.into_mut(), Either::Right(_))); + + // now we clone the bitmap increasing the ref count + let bitmap = Bitmap::from_iter([true, false, true]); + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.into(), Some(bitmap.clone())); + assert!(matches!(arr.into_mut(), Either::Left(_))); +} diff --git a/crates/polars/tests/it/arrow/array/struct_/iterator.rs b/crates/polars/tests/it/arrow/array/struct_/iterator.rs new file mode 100644 index 000000000000..5b4b0b784d13 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/struct_/iterator.rs @@ -0,0 +1,28 @@ +use arrow::array::*; +use arrow::datatypes::*; +use arrow::scalar::new_scalar; + +#[test] +fn test_simple_iter() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b", ArrowDataType::Boolean, false), + Field::new("c", ArrowDataType::Int32, false), + ]; + + let array = StructArray::new( + ArrowDataType::Struct(fields), + vec![boolean.clone(), int.clone()], + None, + ); + + for (i, item) in array.iter().enumerate() { + let expected = Some(vec![ + new_scalar(boolean.as_ref(), i), + new_scalar(int.as_ref(), i), + ]); + assert_eq!(expected, item); + } +} diff --git a/crates/polars/tests/it/arrow/array/struct_/mod.rs b/crates/polars/tests/it/arrow/array/struct_/mod.rs new file mode 100644 index 000000000000..ae1a0c0a37cb --- /dev/null +++ b/crates/polars/tests/it/arrow/array/struct_/mod.rs @@ -0,0 +1,27 @@ +mod iterator; +mod mutable; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::datatypes::*; + +#[test] +fn debug() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b", ArrowDataType::Boolean, false), + Field::new("c", ArrowDataType::Int32, false), + ]; + + let array = StructArray::new( + ArrowDataType::Struct(fields), + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + assert_eq!( + format!("{array:?}"), + "StructArray[{b: false, c: 42}, {b: false, c: 28}, None, {b: true, c: 31}]" + ); +} diff --git a/crates/polars/tests/it/arrow/array/struct_/mutable.rs b/crates/polars/tests/it/arrow/array/struct_/mutable.rs new file mode 100644 index 000000000000..e9d698aa1bb3 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/struct_/mutable.rs @@ -0,0 +1,31 @@ +use arrow::array::*; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn push() { + let c1 = Box::new(MutablePrimitiveArray::::new()) as Box; + let values = vec![c1]; + let data_type = ArrowDataType::Struct(vec![Field::new("f1", ArrowDataType::Int32, true)]); + let mut a = MutableStructArray::new(data_type, values); + + a.value::>(0) + .unwrap() + .push(Some(1)); + a.push(true); + a.value::>(0).unwrap().push(None); + a.push(false); + a.value::>(0) + .unwrap() + .push(Some(2)); + a.push(true); + + assert_eq!(a.len(), 3); + assert!(a.is_valid(0)); + assert!(!a.is_valid(1)); + assert!(a.is_valid(2)); + + assert_eq!( + a.value::>(0).unwrap().values(), + &Vec::from([1, 0, 2]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/union.rs b/crates/polars/tests/it/arrow/array/union.rs new file mode 100644 index 000000000000..b358aa8e44bb --- /dev/null +++ b/crates/polars/tests/it/arrow/array/union.rs @@ -0,0 +1,371 @@ +use arrow::array::*; +use arrow::buffer::Buffer; +use arrow::datatypes::*; +use arrow::scalar::{new_scalar, PrimitiveScalar, Scalar, UnionScalar, Utf8Scalar}; +use polars_error::PolarsResult; + +fn next_unwrap(iter: &mut I) -> T +where + I: Iterator>, + T: Clone + 'static, +{ + iter.next() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .clone() +} + +#[test] +fn sparse_debug() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields, None); + + assert_eq!(format!("{array:?}"), "UnionArray[1, None, c]"); + + Ok(()) +} + +#[test] +fn dense_debug() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + let offsets = Some(vec![0, 1, 0].into()); + + let array = UnionArray::new(data_type, types, fields, offsets); + + assert_eq!(format!("{array:?}"), "UnionArray[1, None, c]"); + + Ok(()) +} + +#[test] +fn slice() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::LargeUtf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let types = Buffer::from(vec![0, 0, 1]); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type.clone(), types, fields.clone(), None); + + let result = array.sliced(1, 2); + + let sliced_types = Buffer::from(vec![0, 1]); + let sliced_fields = vec![ + Int32Array::from(&[None, Some(2)]).boxed(), + Utf8Array::::from([Some("b"), Some("c")]).boxed(), + ]; + let expected = UnionArray::new(data_type, sliced_types, sliced_fields, None); + + assert_eq!(expected, result); + Ok(()) +} + +#[test] +fn iter_sparse() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let types = Buffer::from(vec![0, 0, 1]); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields.clone(), None); + let mut iter = array.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(1) + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &None + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + Some("c") + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn iter_dense() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let types = Buffer::from(vec![0, 0, 1]); + let offsets = Buffer::::from(vec![0, 1, 0]); + let fields = vec![ + Int32Array::from(&[Some(1), None]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields.clone(), Some(offsets)); + let mut iter = array.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(1) + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &None + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + Some("c") + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn iter_sparse_slice() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let types = Buffer::from(vec![0, 0, 1]); + let fields = vec![ + Int32Array::from(&[Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields.clone(), None); + let array_slice = array.sliced(1, 1); + let mut iter = array_slice.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(3) + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn iter_dense_slice() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let types = Buffer::from(vec![0, 0, 1]); + let offsets = Buffer::::from(vec![0, 1, 0]); + let fields = vec![ + Int32Array::from(&[Some(1), Some(3)]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields.clone(), Some(offsets)); + let array_slice = array.sliced(1, 1); + let mut iter = array_slice.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(3) + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn scalar() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let types = Buffer::from(vec![0, 0, 1]); + let offsets = Buffer::::from(vec![0, 1, 0]); + let fields = vec![ + Int32Array::from(&[Some(1), None]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields.clone(), Some(offsets)); + + let scalar = new_scalar(&array, 0); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + &Some(1) + ); + assert_eq!(union_scalar.type_(), 0); + let scalar = new_scalar(&array, 1); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + &None + ); + assert_eq!(union_scalar.type_(), 0); + + let scalar = new_scalar(&array, 2); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + Some("c") + ); + assert_eq!(union_scalar.type_(), 1); + + Ok(()) +} + +#[test] +fn dense_without_offsets_is_error() { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); +} + +#[test] +fn fields_must_match() { + let fields = vec![ + Field::new("a", ArrowDataType::Int64, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); +} + +#[test] +fn sparse_with_offsets_is_error() { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + let offsets = vec![0, 1, 0].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn offsets_must_be_in_bounds() { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + // it must be equal to length og types + let offsets = vec![0, 1].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn sparse_with_wrong_offsets1_is_error() { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + // it must be equal to length of types + let offsets = vec![0, 1, 10].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn types_must_be_in_bounds() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + // 10 > num fields + let types = vec![0, 10].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); + Ok(()) +} diff --git a/crates/polars/tests/it/arrow/array/utf8/mod.rs b/crates/polars/tests/it/arrow/array/utf8/mod.rs new file mode 100644 index 000000000000..fb75990dad29 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/mod.rs @@ -0,0 +1,237 @@ +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::OffsetsBuffer; +use polars_error::PolarsResult; + +mod mutable; +mod mutable_values; +mod to_mutable; + +#[test] +fn basics() { + let data = vec![Some("hello"), None, Some("hello2")]; + + let array: Utf8Array = data.into_iter().collect(); + + assert_eq!(array.value(0), "hello"); + assert_eq!(array.value(1), ""); + assert_eq!(array.value(2), "hello2"); + assert_eq!(unsafe { array.value_unchecked(2) }, "hello2"); + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 5, 11]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = Utf8Array::::new( + ArrowDataType::Utf8, + array.offsets().clone(), + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert_eq!(array.value(0), ""); + assert_eq!(array.value(1), "hello2"); + // note how this keeps everything: the offsets were sliced + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[5, 5, 11]); +} + +#[test] +fn empty() { + let array = Utf8Array::::new_empty(ArrowDataType::Utf8); + assert_eq!(array.values().as_slice(), b""); + assert_eq!(array.offsets().as_slice(), &[0]); + assert_eq!(array.validity(), None); +} + +#[test] +fn from() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + + let a = array.validity().unwrap(); + assert_eq!(a, &Bitmap::from([true, true, false])); +} + +#[test] +fn from_slice() { + let b = Utf8Array::::from_slice(["a", "b", "cc"]); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn from_iter_values() { + let b = Utf8Array::::from_iter_values(["a", "b", "cc"].iter()); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn from_trusted_len_iter() { + let b = + Utf8Array::::from_trusted_len_iter(vec![Some("a"), Some("b"), Some("cc")].into_iter()); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn try_from_trusted_len_iter() { + let b = Utf8Array::::try_from_trusted_len_iter( + vec![Some("a"), Some("b"), Some("cc")] + .into_iter() + .map(PolarsResult::Ok), + ) + .unwrap(); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn not_utf8() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![0, 159, 146, 150].into(); // invalid utf8 + assert!(Utf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +fn not_utf8_individually() { + let offsets = vec![0, 1, 2].try_into().unwrap(); + let values = vec![207, 128].into(); // each is invalid utf8, but together is valid + assert!(Utf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +fn wrong_data_type() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = b"abbb".to_vec().into(); + assert!(Utf8Array::::try_new(ArrowDataType::Int32, offsets, values, None).is_err()); +} + +#[test] +fn out_of_bounds_offsets_panics() { + // the 10 is out of bounds + let offsets = vec![0, 10, 11].try_into().unwrap(); + let values = b"abbb".to_vec().into(); + assert!(Utf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +#[should_panic] +fn index_out_of_bounds_panics() { + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abbb".to_vec().into(); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None); + + array.value(3); +} + +#[test] +fn debug() { + let array = Utf8Array::::from([Some("aa"), Some(""), None]); + + assert_eq!(format!("{array:?}"), "Utf8Array[aa, , None]"); +} + +#[test] +fn into_mut_1() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let a = values.clone(); // cloned values + assert_eq!(a, values); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_2() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let values = b"a".to_vec().into(); + let a = offsets.clone(); // cloned offsets + assert_eq!(a, offsets); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_3() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = b"a".to_vec().into(); + let validity = Some([true].into()); + let a = validity.clone(); // cloned validity + assert_eq!(a, validity); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, validity); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_4() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = b"a".to_vec().into(); + let validity = Some([true].into()); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, validity); + assert!(array.into_mut().is_right()); +} + +#[test] +fn rev_iter() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + + assert_eq!( + array.into_iter().rev().collect::>(), + vec![None, Some(" "), Some("hello")] + ); +} + +#[test] +fn iter_nth() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + + assert_eq!(array.iter().nth(1), Some(Some(" "))); + assert_eq!(array.iter().nth(10), None); +} + +#[test] +fn test_apply_validity() { + let mut array = Utf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + array.set_validity(Some([true, true, true].into())); + + array.apply_validity(|bitmap| { + let mut mut_bitmap = bitmap.into_mut().right().unwrap(); + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap.into() + }); + + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(!array.is_valid(2)); +} diff --git a/crates/polars/tests/it/arrow/array/utf8/mutable.rs b/crates/polars/tests/it/arrow/array/utf8/mutable.rs new file mode 100644 index 000000000000..8db873a90d10 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/mutable.rs @@ -0,0 +1,242 @@ +use arrow::array::{MutableArray, MutableUtf8Array, TryExtendFromSelf, Utf8Array}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; + +#[test] +fn capacities() { + let b = MutableUtf8Array::::with_capacities(1, 10); + + assert!(b.values().capacity() >= 10); + assert!(b.offsets().capacity() >= 1); +} + +#[test] +fn push_null() { + let mut array = MutableUtf8Array::::new(); + array.push::<&str>(None); + + let array: Utf8Array = array.into(); + assert_eq!(array.validity(), Some(&Bitmap::from([false]))); +} + +#[test] +fn pop() { + let mut a = MutableUtf8Array::::new(); + a.push(Some("first")); + a.push(Some("second")); + a.push(Some("third")); + a.push::<&str>(None); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), Some("third".to_owned())); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), Some("second".to_string())); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some("first".to_string())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); + assert!(a.is_empty()); +} + +#[test] +fn pop_all_some() { + let mut a = MutableUtf8Array::::new(); + a.push(Some("first")); + a.push(Some("second")); + a.push(Some("third")); + a.push(Some("fourth")); + for _ in 0..4 { + a.push(Some("aaaa")); + } + a.push(Some("こんにちは")); + + assert_eq!(a.pop(), Some("こんにちは".to_string())); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.len(), 5); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.pop(), Some("fourth".to_string())); + assert_eq!(a.pop(), Some("third".to_string())); + assert_eq!(a.pop(), Some("second".to_string())); + assert_eq!(a.pop(), Some("first".to_string())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); +} + +/// Safety guarantee +#[test] +fn not_utf8() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![0, 159, 146, 150]; // invalid utf8 + assert!(MutableUtf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +fn wrong_data_type() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![1, 2, 3, 4]; + assert!(MutableUtf8Array::::try_new(ArrowDataType::Int8, offsets, values, None).is_err()); +} + +#[test] +fn test_extend_trusted_len_values() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len_values(["hi", "there"].iter()); + array.extend_trusted_len_values(["hello"].iter()); + array.extend_trusted_len(vec![Some("again"), None].into_iter()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 12, 17, 17]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00001111], 5)) + ); +} + +#[test] +fn test_extend_trusted_len() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len(vec![Some("hi"), Some("there")].into_iter()); + array.extend_trusted_len(vec![None, Some("hello")].into_iter()); + array.extend_trusted_len_values(["again"].iter()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 7, 12, 17]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00011011], 5)) + ); +} + +#[test] +fn test_extend_values() { + let mut array = MutableUtf8Array::::new(); + + array.extend_values([Some("hi"), None, Some("there"), None].iter().flatten()); + array.extend_values([Some("hello"), None].iter().flatten()); + array.extend_values(vec![Some("again"), None].into_iter().flatten()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 12, 17]); + assert_eq!(array.validity(), None,); +} + +#[test] +fn test_extend() { + let mut array = MutableUtf8Array::::new(); + + array.extend([Some("hi"), None, Some("there"), None]); + + let array: Utf8Array = array.into(); + + assert_eq!( + array, + Utf8Array::::from([Some("hi"), None, Some("there"), None]) + ); +} + +#[test] +fn as_arc() { + let mut array = MutableUtf8Array::::new(); + + array.extend([Some("hi"), None, Some("there"), None]); + + assert_eq!( + Utf8Array::::from([Some("hi"), None, Some("there"), None]), + array.as_arc().as_ref() + ); +} + +#[test] +fn test_iter() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len(vec![Some("hi"), Some("there")].into_iter()); + array.extend_trusted_len(vec![None, Some("hello")].into_iter()); + array.extend_trusted_len_values(["again"].iter()); + + let result = array.iter().collect::>(); + assert_eq!( + result, + vec![ + Some("hi"), + Some("there"), + None, + Some("hello"), + Some("again"), + ] + ); +} + +#[test] +fn as_box_twice() { + let mut a = MutableUtf8Array::::new(); + let _ = a.as_box(); + let _ = a.as_box(); + let mut a = MutableUtf8Array::::new(); + let _ = a.as_arc(); + let _ = a.as_arc(); +} + +#[test] +fn extend_from_self() { + let mut a = MutableUtf8Array::::from([Some("aa"), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableUtf8Array::::from([Some("aa"), None, Some("aa"), None]) + ); +} + +#[test] +fn test_set_validity() { + let mut array = MutableUtf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + array.set_validity(Some([false, false, true].into())); + + assert!(!array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); +} + +#[test] +fn test_apply_validity() { + let mut array = MutableUtf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + array.set_validity(Some([true, true, true].into())); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(!array.is_valid(2)); +} + +#[test] +fn test_apply_validity_with_no_validity_inited() { + let mut array = MutableUtf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(array.is_valid(1)); + assert!(array.is_valid(2)); +} diff --git a/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs b/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs new file mode 100644 index 000000000000..d4a309949934 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs @@ -0,0 +1,105 @@ +use arrow::array::{MutableArray, MutableUtf8ValuesArray}; +use arrow::datatypes::ArrowDataType; + +#[test] +fn capacity() { + let mut b = MutableUtf8ValuesArray::::with_capacity(100); + + assert_eq!(b.values().capacity(), 0); + assert!(b.offsets().capacity() >= 100); + b.shrink_to_fit(); + assert!(b.offsets().capacity() < 100); +} + +#[test] +fn offsets_must_be_in_bounds() { + let offsets = vec![0, 10].try_into().unwrap(); + let values = b"abbbbb".to_vec(); + assert!(MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).is_err()); +} + +#[test] +fn data_type_must_be_consistent() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = b"abbb".to_vec(); + assert!(MutableUtf8ValuesArray::::try_new(ArrowDataType::Int32, offsets, values).is_err()); +} + +#[test] +fn must_be_utf8() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![0, 159, 146, 150]; + assert!(std::str::from_utf8(&values).is_err()); + assert!(MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).is_err()); +} + +#[test] +fn as_box() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + let _ = b.as_box(); +} + +#[test] +fn as_arc() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + let _ = b.as_arc(); +} + +#[test] +fn extend_trusted_len() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 2, 3, 4].try_into().unwrap(); + let values = b"abab".to_vec(); + assert_eq!( + b.as_box(), + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn from_trusted_len() { + let mut b = MutableUtf8ValuesArray::::from_trusted_len_iter(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 1, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + assert_eq!( + b.as_box(), + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn extend_from_iter() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let a = b.clone(); + b.extend_trusted_len(a.iter()); + + let offsets = vec![0, 2, 3, 4, 6, 7, 8].try_into().unwrap(); + let values = b"abababab".to_vec(); + assert_eq!( + b.as_box(), + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values) + .unwrap() + .as_box() + ) +} diff --git a/crates/polars/tests/it/arrow/array/utf8/to_mutable.rs b/crates/polars/tests/it/arrow/array/utf8/to_mutable.rs new file mode 100644 index 000000000000..5f2624368bf7 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/to_mutable.rs @@ -0,0 +1,71 @@ +use arrow::array::Utf8Array; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::OffsetsBuffer; + +#[test] +fn not_shared() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + assert!(array.into_mut().is_right()); +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_validity() { + let validity = Bitmap::from([true]); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + vec![0, 1].try_into().unwrap(), + b"a".to_vec().into(), + Some(validity.clone()), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_values() { + let values: Buffer = b"a".to_vec().into(); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + vec![0, 1].try_into().unwrap(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets_values() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let values: Buffer = b"a".to_vec().into(); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + offsets.clone(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + offsets.clone(), + b"a".to_vec().into(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_all() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + assert!(array.clone().into_mut().is_left()) +} diff --git a/crates/polars/tests/it/arrow/bitmap/assign_ops.rs b/crates/polars/tests/it/arrow/bitmap/assign_ops.rs new file mode 100644 index 000000000000..939133f0a5fe --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/assign_ops.rs @@ -0,0 +1,78 @@ +use arrow::bitmap::{binary_assign, unary_assign, Bitmap, MutableBitmap}; +use proptest::prelude::*; + +use super::bitmap_strategy; + +#[test] +fn basics() { + let mut b = MutableBitmap::from_iter(std::iter::repeat(true).take(10)); + unary_assign(&mut b, |x: u8| !x); + assert_eq!( + b, + MutableBitmap::from_iter(std::iter::repeat(false).take(10)) + ); + + let mut b = MutableBitmap::from_iter(std::iter::repeat(true).take(10)); + let c = Bitmap::from_iter(std::iter::repeat(true).take(10)); + binary_assign(&mut b, &c, |x: u8, y| x | y); + assert_eq!( + b, + MutableBitmap::from_iter(std::iter::repeat(true).take(10)) + ); +} + +#[test] +fn binary_assign_oob() { + // this check we don't have an oob access if the bitmaps are size T + 1 + // and we do some slicing. + let a = MutableBitmap::from_iter(std::iter::repeat(true).take(65)); + let b = MutableBitmap::from_iter(std::iter::repeat(true).take(65)); + + let a: Bitmap = a.into(); + let a = a.sliced(10, 20); + + let b: Bitmap = b.into(); + let b = b.sliced(10, 20); + + let mut a = a.make_mut(); + + binary_assign(&mut a, &b, |x: u64, y| x & y); +} + +#[test] +fn fast_paths() { + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([true, true]); + let b = b & &c; + assert_eq!(b, MutableBitmap::from_iter([true, false])); + + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([false, false]); + let b = b & &c; + assert_eq!(b, MutableBitmap::from_iter([false, false])); + + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([true, true]); + let b = b | &c; + assert_eq!(b, MutableBitmap::from_iter([true, true])); + + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([false, false]); + let b = b | &c; + assert_eq!(b, MutableBitmap::from_iter([true, false])); +} + +proptest! { + /// Asserts that !bitmap equals all bits flipped + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn not(b in bitmap_strategy()) { + let not_b: MutableBitmap = b.iter().map(|x| !x).collect(); + + let mut b = b.make_mut(); + + unary_assign(&mut b, |x: u8| !x); + + assert_eq!(b, not_b); + } +} diff --git a/crates/polars/tests/it/arrow/bitmap/bitmap_ops.rs b/crates/polars/tests/it/arrow/bitmap/bitmap_ops.rs new file mode 100644 index 000000000000..e7fb3636e218 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/bitmap_ops.rs @@ -0,0 +1,40 @@ +use arrow::bitmap::{and, or, xor, Bitmap}; +use proptest::prelude::*; + +use super::bitmap_strategy; + +proptest! { + /// Asserts that !bitmap equals all bits flipped + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn not(bitmap in bitmap_strategy()) { + let not_bitmap: Bitmap = bitmap.iter().map(|x| !x).collect(); + + assert_eq!(!&bitmap, not_bitmap); + } +} + +#[test] +fn test_fast_paths() { + let all_true = Bitmap::from(&[true, true]); + let all_false = Bitmap::from(&[false, false]); + let toggled = Bitmap::from(&[true, false]); + + assert_eq!(and(&all_true, &all_true), all_true); + assert_eq!(and(&all_false, &all_true), all_false); + assert_eq!(and(&all_true, &all_false), all_false); + assert_eq!(and(&toggled, &all_false), all_false); + assert_eq!(and(&toggled, &all_true), toggled); + + assert_eq!(or(&all_true, &all_true), all_true); + assert_eq!(or(&all_true, &all_false), all_true); + assert_eq!(or(&all_false, &all_true), all_true); + assert_eq!(or(&all_false, &all_false), all_false); + assert_eq!(or(&toggled, &all_false), toggled); + + assert_eq!(xor(&all_true, &all_true), all_false); + assert_eq!(xor(&all_true, &all_false), all_true); + assert_eq!(xor(&all_false, &all_true), all_true); + assert_eq!(xor(&all_false, &all_false), all_false); + assert_eq!(xor(&toggled, &toggled), all_false); +} diff --git a/crates/polars/tests/it/arrow/bitmap/immutable.rs b/crates/polars/tests/it/arrow/bitmap/immutable.rs new file mode 100644 index 000000000000..29324c96d771 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/immutable.rs @@ -0,0 +1,67 @@ +use arrow::bitmap::Bitmap; + +#[test] +fn as_slice() { + let b = Bitmap::from([true, true, true, true, true, true, true, true, true]); + + let (slice, offset, length) = b.as_slice(); + assert_eq!(slice, &[0b11111111, 0b1]); + assert_eq!(offset, 0); + assert_eq!(length, 9); +} + +#[test] +fn as_slice_offset() { + let b = Bitmap::from([true, true, true, true, true, true, true, true, true]); + let b = b.sliced(8, 1); + + let (slice, offset, length) = b.as_slice(); + assert_eq!(slice, &[0b1]); + assert_eq!(offset, 0); + assert_eq!(length, 1); +} + +#[test] +fn as_slice_offset_middle() { + let b = Bitmap::from_u8_slice([0, 0, 0, 0b00010101], 27); + let b = b.sliced(22, 5); + + let (slice, offset, length) = b.as_slice(); + assert_eq!(slice, &[0, 0b00010101]); + assert_eq!(offset, 6); + assert_eq!(length, 5); +} + +#[test] +fn debug() { + let b = Bitmap::from([true, true, false, true, true, true, true, true, true]); + let b = b.sliced(2, 7); + + assert_eq!(format!("{b:?}"), "[0b111110__, 0b_______1]"); +} + +#[test] +#[cfg(feature = "arrow")] +fn from_arrow() { + use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; + let buffer = arrow_buffer::Buffer::from_iter(vec![true, true, true, false, false, false, true]); + let bools = BooleanBuffer::new(buffer, 0, 7); + let nulls = NullBuffer::new(bools); + assert_eq!(nulls.null_count(), 3); + + let bitmap = Bitmap::from_null_buffer(nulls.clone()); + assert_eq!(nulls.null_count(), bitmap.unset_bits()); + assert_eq!(nulls.len(), bitmap.len()); + let back = NullBuffer::from(bitmap); + assert_eq!(nulls, back); + + let nulls = nulls.slice(1, 3); + assert_eq!(nulls.null_count(), 1); + assert_eq!(nulls.len(), 3); + + let bitmap = Bitmap::from_null_buffer(nulls.clone()); + assert_eq!(nulls.null_count(), bitmap.unset_bits()); + assert_eq!(nulls.len(), bitmap.len()); + let back = NullBuffer::from(bitmap); + assert_eq!(nulls, back); +} diff --git a/crates/polars/tests/it/arrow/bitmap/mod.rs b/crates/polars/tests/it/arrow/bitmap/mod.rs new file mode 100644 index 000000000000..88758695b762 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/mod.rs @@ -0,0 +1,124 @@ +mod assign_ops; +mod bitmap_ops; +mod immutable; +mod mutable; +mod utils; + +use arrow::bitmap::Bitmap; +use proptest::prelude::*; + +/// Returns a strategy of an arbitrary sliced [`Bitmap`] of size up to 1000 +pub(crate) fn bitmap_strategy() -> impl Strategy { + prop::collection::vec(any::(), 1..1000) + .prop_flat_map(|vec| { + let len = vec.len(); + (Just(vec), 0..len) + }) + .prop_flat_map(|(vec, index)| { + let len = vec.len(); + (Just(vec), Just(index), 0..len - index) + }) + .prop_flat_map(|(vec, index, len)| { + let bitmap = Bitmap::from(&vec); + let bitmap = bitmap.sliced(index, len); + Just(bitmap) + }) +} + +fn create_bitmap>(bytes: P, len: usize) -> Bitmap { + let buffer = Vec::::from(bytes.as_ref()); + Bitmap::from_u8_vec(buffer, len) +} + +#[test] +fn eq() { + let lhs = create_bitmap([0b01101010], 8); + let rhs = create_bitmap([0b01001110], 8); + assert!(lhs != rhs); +} + +#[test] +fn eq_len() { + let lhs = create_bitmap([0b01101010], 6); + let rhs = create_bitmap([0b00101010], 6); + assert!(lhs == rhs); + let rhs = create_bitmap([0b00001010], 6); + assert!(lhs != rhs); +} + +#[test] +fn eq_slice() { + let lhs = create_bitmap([0b10101010], 8).sliced(1, 7); + let rhs = create_bitmap([0b10101011], 8).sliced(1, 7); + assert!(lhs == rhs); + + let lhs = create_bitmap([0b10101010], 8).sliced(2, 6); + let rhs = create_bitmap([0b10101110], 8).sliced(2, 6); + assert!(lhs != rhs); +} + +#[test] +fn and() { + let lhs = create_bitmap([0b01101010], 8); + let rhs = create_bitmap([0b01001110], 8); + let expected = create_bitmap([0b01001010], 8); + assert_eq!(&lhs & &rhs, expected); +} + +#[test] +fn or_large() { + let input: &[u8] = &[ + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000010, 0b11111111, + ]; + let input1: &[u8] = &[ + 0b00000000, 0b00000001, 0b10000000, 0b10000000, 0b10000000, 0b10000000, 0b10000000, + 0b10000000, 0b11111111, + ]; + let expected: &[u8] = &[ + 0b00000000, 0b00000001, 0b10000010, 0b10000100, 0b10001000, 0b10010000, 0b10100000, + 0b11000010, 0b11111111, + ]; + + let lhs = create_bitmap(input, 62); + let rhs = create_bitmap(input1, 62); + let expected = create_bitmap(expected, 62); + assert_eq!(&lhs | &rhs, expected); +} + +#[test] +fn and_offset() { + let lhs = create_bitmap([0b01101011], 8).sliced(1, 7); + let rhs = create_bitmap([0b01001111], 8).sliced(1, 7); + let expected = create_bitmap([0b01001010], 8).sliced(1, 7); + assert_eq!(&lhs & &rhs, expected); +} + +#[test] +fn or() { + let lhs = create_bitmap([0b01101010], 8); + let rhs = create_bitmap([0b01001110], 8); + let expected = create_bitmap([0b01101110], 8); + assert_eq!(&lhs | &rhs, expected); +} + +#[test] +fn not() { + let lhs = create_bitmap([0b01101010], 6); + let expected = create_bitmap([0b00010101], 6); + assert_eq!(!&lhs, expected); +} + +#[test] +fn subslicing_gives_correct_null_count() { + let base = Bitmap::from([false, true, true, false, false, true, true, true]); + assert_eq!(base.unset_bits(), 3); + + let view1 = base.clone().sliced(0, 1); + let view2 = base.sliced(1, 7); + assert_eq!(view1.unset_bits(), 1); + assert_eq!(view2.unset_bits(), 2); + + let view3 = view2.sliced(0, 1); + assert_eq!(view3.unset_bits(), 0); +} diff --git a/crates/polars/tests/it/arrow/bitmap/mutable.rs b/crates/polars/tests/it/arrow/bitmap/mutable.rs new file mode 100644 index 000000000000..af37d634a468 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/mutable.rs @@ -0,0 +1,437 @@ +use arrow::bitmap::{Bitmap, MutableBitmap}; + +#[test] +fn from_slice() { + let slice = &[true, false, true]; + let a = MutableBitmap::from(slice); + assert_eq!(a.iter().collect::>(), slice); +} + +#[test] +fn from_len_zeroed() { + let a = MutableBitmap::from_len_zeroed(10); + assert_eq!(a.len(), 10); + assert_eq!(a.unset_bits(), 10); +} + +#[test] +fn from_len_set() { + let a = MutableBitmap::from_len_set(10); + assert_eq!(a.len(), 10); + assert_eq!(a.unset_bits(), 0); +} + +#[test] +fn try_new_invalid() { + assert!(MutableBitmap::try_new(vec![], 2).is_err()); +} + +#[test] +fn clear() { + let mut a = MutableBitmap::from_len_zeroed(10); + a.clear(); + assert_eq!(a.len(), 0); +} + +#[test] +fn trusted_len() { + let data = vec![true; 65]; + let bitmap = MutableBitmap::from_trusted_len_iter(data.into_iter()); + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 65); + + assert_eq!(bitmap.as_slice().0[8], 0b00000001); +} + +#[test] +fn trusted_len_small() { + let data = vec![true; 7]; + let bitmap = MutableBitmap::from_trusted_len_iter(data.into_iter()); + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 7); + + assert_eq!(bitmap.as_slice().0[0], 0b01111111); +} + +#[test] +fn push() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(true); + bitmap.push(false); + bitmap.push(false); + for _ in 0..7 { + bitmap.push(true) + } + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 10); + + assert_eq!(bitmap.as_slice().0, &[0b11111001, 0b00000011]); +} + +#[test] +fn push_small() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(true); + bitmap.push(true); + bitmap.push(false); + let bitmap: Option = bitmap.into(); + let bitmap = bitmap.unwrap(); + assert_eq!(bitmap.len(), 3); + assert_eq!(bitmap.as_slice().0[0], 0b00000011); +} + +#[test] +fn push_exact_zeros() { + let mut bitmap = MutableBitmap::new(); + for _ in 0..8 { + bitmap.push(false) + } + let bitmap: Option = bitmap.into(); + let bitmap = bitmap.unwrap(); + assert_eq!(bitmap.len(), 8); + assert_eq!(bitmap.as_slice().0.len(), 1); +} + +#[test] +fn push_exact_ones() { + let mut bitmap = MutableBitmap::new(); + for _ in 0..8 { + bitmap.push(true) + } + let bitmap: Option = bitmap.into(); + assert!(bitmap.is_none()); +} + +#[test] +fn pop() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(false); + bitmap.push(true); + bitmap.push(false); + bitmap.push(true); + + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 3); + + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 2); + + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 2); + assert_eq!(bitmap.as_slice().0[0], 0b00001010); +} + +#[test] +fn pop_large() { + let mut bitmap = MutableBitmap::new(); + for _ in 0..8 { + bitmap.push(true); + } + + bitmap.push(false); + bitmap.push(true); + bitmap.push(false); + + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 10); + + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 9); + + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 8); + + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 8); + assert_eq!(bitmap.as_slice().0, &[0b11111111]); +} + +#[test] +fn pop_all() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(false); + bitmap.push(true); + bitmap.push(true); + bitmap.push(true); + + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 3); + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 2); + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 0); + assert_eq!(bitmap.pop(), None); + assert_eq!(bitmap.len(), 0); +} + +#[test] +fn capacity() { + let b = MutableBitmap::with_capacity(10); + assert!(b.capacity() >= 10); +} + +#[test] +fn capacity_push() { + let mut b = MutableBitmap::with_capacity(512); + (0..512).for_each(|_| b.push(true)); + assert_eq!(b.capacity(), 512); + b.reserve(8); + assert_eq!(b.capacity(), 1024); +} + +#[test] +fn extend() { + let mut b = MutableBitmap::new(); + + let iter = (0..512).map(|i| i % 6 == 0); + unsafe { b.extend_from_trusted_len_iter_unchecked(iter) }; + let b: Bitmap = b.into(); + for (i, v) in b.iter().enumerate() { + assert_eq!(i % 6 == 0, v); + } +} + +#[test] +fn extend_offset() { + let mut b = MutableBitmap::new(); + b.push(true); + + let iter = (0..512).map(|i| i % 6 == 0); + unsafe { b.extend_from_trusted_len_iter_unchecked(iter) }; + let b: Bitmap = b.into(); + let mut iter = b.iter().enumerate(); + assert!(iter.next().unwrap().1); + for (i, v) in iter { + assert_eq!((i - 1) % 6 == 0, v); + } +} + +#[test] +fn set() { + let mut bitmap = MutableBitmap::from_len_zeroed(12); + bitmap.set(0, true); + assert!(bitmap.get(0)); + bitmap.set(0, false); + assert!(!bitmap.get(0)); + + bitmap.set(11, true); + assert!(bitmap.get(11)); + bitmap.set(11, false); + assert!(!bitmap.get(11)); + bitmap.set(11, true); + + let bitmap: Option = bitmap.into(); + let bitmap = bitmap.unwrap(); + assert_eq!(bitmap.len(), 12); + assert_eq!(bitmap.as_slice().0[0], 0b00000000); +} + +#[test] +fn extend_from_bitmap() { + let other = Bitmap::from(&[true, false, true]); + let mut bitmap = MutableBitmap::new(); + + // call is optimized to perform a memcopy + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 3); + assert_eq!(bitmap.as_slice()[0], 0b00000101); + + // this call iterates over all bits + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 6); + assert_eq!(bitmap.as_slice()[0], 0b00101101); +} + +#[test] +fn extend_from_bitmap_offset() { + let other = Bitmap::from_u8_slice([0b00111111], 8); + let mut bitmap = MutableBitmap::from_vec(vec![1, 0, 0b00101010], 22); + + // call is optimized to perform a memcopy + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 22 + 8); + assert_eq!(bitmap.as_slice(), &[1, 0, 0b11101010, 0b00001111]); + + // more than one byte + let other = Bitmap::from_u8_slice([0b00111111, 0b00001111, 0b0001100], 20); + let mut bitmap = MutableBitmap::from_vec(vec![1, 0, 0b00101010], 22); + + // call is optimized to perform a memcopy + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 22 + 20); + assert_eq!( + bitmap.as_slice(), + &[1, 0, 0b11101010, 0b11001111, 0b0000011, 0b0000011] + ); +} + +#[test] +fn debug() { + let mut b = MutableBitmap::new(); + assert_eq!(format!("{b:?}"), "[]"); + b.push(true); + b.push(false); + assert_eq!(format!("{b:?}"), "[0b______01]"); + b.push(false); + b.push(false); + b.push(false); + b.push(false); + b.push(true); + b.push(true); + assert_eq!(format!("{b:?}"), "[0b11000001]"); + b.push(true); + assert_eq!(format!("{b:?}"), "[0b11000001, 0b_______1]"); +} + +#[test] +fn extend_set() { + let mut b = MutableBitmap::new(); + b.extend_constant(6, true); + assert_eq!(b.as_slice(), &[0b11111111]); + assert_eq!(b.len(), 6); + + let mut b = MutableBitmap::from(&[false]); + b.extend_constant(6, true); + assert_eq!(b.as_slice(), &[0b01111110]); + assert_eq!(b.len(), 1 + 6); + + let mut b = MutableBitmap::from(&[false]); + b.extend_constant(9, true); + assert_eq!(b.as_slice(), &[0b11111110, 0b11111111]); + assert_eq!(b.len(), 1 + 9); + + let mut b = MutableBitmap::from(&[false, false, false, false]); + b.extend_constant(2, true); + assert_eq!(b.as_slice(), &[0b00110000]); + assert_eq!(b.len(), 4 + 2); + + let mut b = MutableBitmap::from(&[false, false, false, false]); + b.extend_constant(8, true); + assert_eq!(b.as_slice(), &[0b11110000, 0b11111111]); + assert_eq!(b.len(), 4 + 8); + + let mut b = MutableBitmap::from(&[true, true]); + b.extend_constant(3, true); + assert_eq!(b.as_slice(), &[0b00011111]); + assert_eq!(b.len(), 2 + 3); +} + +#[test] +fn extend_unset() { + let mut b = MutableBitmap::new(); + b.extend_constant(6, false); + assert_eq!(b.as_slice(), &[0b0000000]); + assert_eq!(b.len(), 6); + + let mut b = MutableBitmap::from(&[true]); + b.extend_constant(6, false); + assert_eq!(b.as_slice(), &[0b00000001]); + assert_eq!(b.len(), 1 + 6); + + let mut b = MutableBitmap::from(&[true]); + b.extend_constant(9, false); + assert_eq!(b.as_slice(), &[0b0000001, 0b00000000]); + assert_eq!(b.len(), 1 + 9); + + let mut b = MutableBitmap::from(&[true, true, true, true]); + b.extend_constant(2, false); + assert_eq!(b.as_slice(), &[0b00001111]); + assert_eq!(b.len(), 4 + 2); +} + +#[test] +fn extend_bitmap() { + let mut b = MutableBitmap::from(&[true]); + b.extend_from_slice(&[0b00011001], 0, 6); + assert_eq!(b.as_slice(), &[0b00110011]); + assert_eq!(b.len(), 1 + 6); + + let mut b = MutableBitmap::from(&[true]); + b.extend_from_slice(&[0b00011001, 0b00011001], 0, 9); + assert_eq!(b.as_slice(), &[0b00110011, 0b00110010]); + assert_eq!(b.len(), 1 + 9); + + let mut b = MutableBitmap::from(&[true, true, true, true]); + b.extend_from_slice(&[0b00011001, 0b00011001], 0, 9); + assert_eq!(b.as_slice(), &[0b10011111, 0b10010001]); + assert_eq!(b.len(), 4 + 9); + + let mut b = MutableBitmap::from(&[true, true, true, true, true]); + b.extend_from_slice(&[0b00001011], 0, 4); + assert_eq!(b.as_slice(), &[0b01111111, 0b00000001]); + assert_eq!(b.len(), 5 + 4); +} + +// TODO! undo miri ignore once issue is fixed in miri +// this test was a memory hog and lead to OOM in CI +// given enough memory it was able to pass successfully on a local +#[test] +#[cfg_attr(miri, ignore)] +fn extend_constant1() { + use std::iter::FromIterator; + for i in 0..64 { + for j in 0..64 { + let mut b = MutableBitmap::new(); + b.extend_constant(i, false); + b.extend_constant(j, true); + assert_eq!( + b, + MutableBitmap::from_iter( + std::iter::repeat(false) + .take(i) + .chain(std::iter::repeat(true).take(j)) + ) + ); + + let mut b = MutableBitmap::new(); + b.extend_constant(i, true); + b.extend_constant(j, false); + assert_eq!( + b, + MutableBitmap::from_iter( + std::iter::repeat(true) + .take(i) + .chain(std::iter::repeat(false).take(j)) + ) + ); + } + } +} + +#[test] +fn extend_bitmap_one() { + for offset in 0..7 { + let mut b = MutableBitmap::new(); + for _ in 0..4 { + b.extend_from_slice(&[!0], offset, 1); + b.extend_from_slice(&[!0], offset, 1); + } + assert_eq!(b.as_slice(), &[0b11111111]); + } +} + +#[test] +fn extend_bitmap_other() { + let mut a = MutableBitmap::from([true, true, true, false, true, true, true, false, true, true]); + a.extend_from_slice(&[0b01111110u8, 0b10111111, 0b11011111, 0b00000111], 20, 2); + assert_eq!( + a, + MutableBitmap::from([ + true, true, true, false, true, true, true, false, true, true, true, false + ]) + ); +} + +#[test] +fn shrink_to_fit() { + let mut a = MutableBitmap::with_capacity(1025); + a.push(false); + a.shrink_to_fit(); + assert!(a.capacity() < 1025); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/bit_chunks_exact.rs b/crates/polars/tests/it/arrow/bitmap/utils/bit_chunks_exact.rs new file mode 100644 index 000000000000..104db7fdc3bb --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/bit_chunks_exact.rs @@ -0,0 +1,33 @@ +use arrow::bitmap::utils::BitChunksExact; + +#[test] +fn basics() { + let mut iter = BitChunksExact::::new(&[0b11111111u8, 0b00000001u8], 9); + assert_eq!(iter.next().unwrap(), 0b11111111u8); + assert_eq!(iter.remainder(), 0b00000001u8); +} + +#[test] +fn basics_u16_small() { + let mut iter = BitChunksExact::::new(&[0b11111111u8], 7); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b0000_0000_1111_1111u16); +} + +#[test] +fn basics_u16() { + let mut iter = BitChunksExact::::new(&[0b11111111u8, 0b00000001u8], 9); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b0000_0001_1111_1111u16); +} + +#[test] +fn remainder_u16() { + let mut iter = BitChunksExact::::new( + &[0b11111111u8, 0b00000001u8, 0b00000001u8, 0b11011011u8], + 23, + ); + assert_eq!(iter.next(), Some(511)); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 1u16); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/chunk_iter.rs b/crates/polars/tests/it/arrow/bitmap/utils/chunk_iter.rs new file mode 100644 index 000000000000..d19b6e51b5ed --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/chunk_iter.rs @@ -0,0 +1,163 @@ +use arrow::bitmap::utils::BitChunks; +use arrow::types::BitChunkIter; + +#[test] +fn basics() { + let mut iter = BitChunks::::new(&[0b00000001u8, 0b00000010u8], 0, 16); + assert_eq!(iter.next().unwrap(), 0b0000_0010_0000_0001u16); + assert_eq!(iter.remainder(), 0); +} + +#[test] +fn remainder() { + let a = BitChunks::::new(&[0b00000001u8, 0b00000010u8, 0b00000100u8], 0, 18); + assert_eq!(a.remainder(), 0b00000100u16); +} + +#[test] +fn remainder_saturating() { + let a = BitChunks::::new(&[0b00000001u8, 0b00000010u8, 0b00000010u8], 0, 18); + assert_eq!(a.remainder(), 0b0000_0000_0000_0010u16); +} + +#[test] +fn basics_offset() { + let mut iter = BitChunks::::new(&[0b00000001u8, 0b00000011u8, 0b00000001u8], 1, 16); + assert_eq!(iter.remainder(), 0); + assert_eq!(iter.next().unwrap(), 0b1000_0001_1000_0000u16); + assert_eq!(iter.next(), None); +} + +#[test] +fn basics_offset_remainder() { + let mut a = BitChunks::::new(&[0b00000001u8, 0b00000011u8, 0b10000001u8], 1, 15); + assert_eq!(a.next(), None); + assert_eq!(a.remainder(), 0b1000_0001_1000_0000u16); + assert_eq!(a.remainder_len(), 15); +} + +#[test] +fn offset_remainder_saturating() { + let a = BitChunks::::new(&[0b00000001u8, 0b00000011u8, 0b00000011u8], 1, 17); + assert_eq!(a.remainder(), 0b0000_0000_0000_0001u16); +} + +#[test] +fn offset_remainder_saturating2() { + let a = BitChunks::::new(&[0b01001001u8, 0b00000001], 1, 8); + assert_eq!(a.remainder(), 0b1010_0100u64); +} + +#[test] +fn offset_remainder_saturating3() { + let input: &[u8] = &[0b01000000, 0b01000001]; + let a = BitChunks::::new(input, 8, 2); + assert_eq!(a.remainder(), 0b0100_0001u64); +} + +#[test] +fn basics_multiple() { + let mut iter = BitChunks::::new( + &[0b00000001u8, 0b00000010u8, 0b00000100u8, 0b00001000u8], + 0, + 4 * 8, + ); + assert_eq!(iter.next().unwrap(), 0b0000_0010_0000_0001u16); + assert_eq!(iter.next().unwrap(), 0b0000_1000_0000_0100u16); + assert_eq!(iter.remainder(), 0); +} + +#[test] +fn basics_multiple_offset() { + let mut iter = BitChunks::::new( + &[ + 0b00000001u8, + 0b00000010u8, + 0b00000100u8, + 0b00001000u8, + 0b00000001u8, + ], + 1, + 4 * 8, + ); + assert_eq!(iter.next().unwrap(), 0b0000_0001_0000_0000u16); + assert_eq!(iter.next().unwrap(), 0b1000_0100_0000_0010u16); + assert_eq!(iter.remainder(), 0); +} + +#[test] +fn remainder_large() { + let input: &[u8] = &[ + 0b00100100, 0b01001001, 0b10010010, 0b00100100, 0b01001001, 0b10010010, 0b00100100, + 0b01001001, 0b10010010, 0b00100100, 0b01001001, 0b10010010, 0b00000100, + ]; + let mut iter = BitChunks::::new(input, 0, 8 * 12 + 4); + assert_eq!(iter.remainder_len(), 100 - 96); + + for j in 0..12 { + let mut a = BitChunkIter::new(iter.next().unwrap(), 8); + for i in 0..8 { + assert_eq!(a.next().unwrap(), (j * 8 + i + 1) % 3 == 0); + } + } + assert_eq!(None, iter.next()); + + let expected_remainder = 0b00000100u8; + assert_eq!(iter.remainder(), expected_remainder); + + let mut a = BitChunkIter::new(expected_remainder, 8); + for i in 0..4 { + assert_eq!(a.next().unwrap(), (i + 1) % 3 == 0); + } +} + +#[test] +fn basics_1() { + let mut iter = BitChunks::::new( + &[0b00000001u8, 0b00000010u8, 0b00000100u8, 0b00001000u8], + 8, + 3 * 8, + ); + assert_eq!(iter.next().unwrap(), 0b0000_0100_0000_0010u16); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b0000_0000_0000_1000u16); + assert_eq!(iter.remainder_len(), 8); +} + +#[test] +fn basics_2() { + let mut iter = BitChunks::::new( + &[0b00000001u8, 0b00000010u8, 0b00000100u8, 0b00001000u8], + 7, + 3 * 8, + ); + assert_eq!(iter.remainder(), 0b0000_0000_0001_0000u16); + assert_eq!(iter.next().unwrap(), 0b0000_1000_0000_0100u16); + assert_eq!(iter.next(), None); +} + +#[test] +fn remainder_1() { + let mut iter = BitChunks::::new(&[0b11111111u8, 0b00000001u8], 0, 9); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b1_1111_1111u64); +} + +#[test] +fn remainder_2() { + // (i % 3 == 0) in bitmap + let input: &[u8] = &[ + 0b01001001, 0b10010010, 0b00100100, 0b01001001, 0b10010010, 0b00100100, 0b01001001, + 0b10010010, 0b00100100, 0b01001001, /* 73 */ + 0b10010010, /* 146 */ + 0b00100100, 0b00001001, + ]; + let offset = 10; // 8 + 2 + let length = 90; + + let mut iter = BitChunks::::new(input, offset, length); + let first: u64 = 0b0100100100100100100100100100100100100100100100100100100100100100; + assert_eq!(first, iter.next().unwrap()); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b10010010010010010010010010u64); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs b/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs new file mode 100644 index 000000000000..b07c50db9011 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs @@ -0,0 +1,40 @@ +use arrow::bitmap::utils::fmt; + +struct A<'a>(&'a [u8], usize, usize); + +impl<'a> std::fmt::Debug for A<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt(self.0, self.1, self.2, f) + } +} + +#[test] +fn test_debug() -> std::fmt::Result { + assert_eq!(format!("{:?}", A(&[1], 0, 0)), "[]"); + assert_eq!(format!("{:?}", A(&[0b11000001], 0, 8)), "[0b11000001]"); + assert_eq!( + format!("{:?}", A(&[0b11000001, 1], 0, 9)), + "[0b11000001, 0b_______1]" + ); + assert_eq!(format!("{:?}", A(&[1], 0, 2)), "[0b______01]"); + assert_eq!(format!("{:?}", A(&[1], 1, 2)), "[0b_____00_]"); + assert_eq!(format!("{:?}", A(&[1], 2, 2)), "[0b____00__]"); + assert_eq!(format!("{:?}", A(&[1], 3, 2)), "[0b___00___]"); + assert_eq!(format!("{:?}", A(&[1], 4, 2)), "[0b__00____]"); + assert_eq!(format!("{:?}", A(&[1], 5, 2)), "[0b_00_____]"); + assert_eq!(format!("{:?}", A(&[1], 6, 2)), "[0b00______]"); + assert_eq!( + format!("{:?}", A(&[0b11000001, 1], 1, 9)), + "[0b1100000_, 0b______01]" + ); + // extra bytes are ignored + assert_eq!( + format!("{:?}", A(&[0b11000001, 1, 1, 1], 1, 9)), + "[0b1100000_, 0b______01]" + ); + assert_eq!( + format!("{:?}", A(&[0b11000001, 1, 1], 2, 16)), + "[0b110000__, 0b00000001, 0b______01]" + ); + Ok(()) +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/iterator.rs b/crates/polars/tests/it/arrow/bitmap/utils/iterator.rs new file mode 100644 index 000000000000..184a428f137b --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/iterator.rs @@ -0,0 +1,44 @@ +use arrow::bitmap::utils::BitmapIter; + +#[test] +fn basic() { + let values = &[0b01011011u8]; + let iter = BitmapIter::new(values, 0, 6); + let result = iter.collect::>(); + assert_eq!(result, vec![true, true, false, true, true, false]) +} + +#[test] +fn large() { + let values = &[0b01011011u8]; + let values = std::iter::repeat(values) + .take(63) + .flatten() + .copied() + .collect::>(); + let len = 63 * 8; + let iter = BitmapIter::new(&values, 0, len); + assert_eq!(iter.count(), len); +} + +#[test] +fn offset() { + let values = &[0b01011011u8]; + let iter = BitmapIter::new(values, 2, 4); + let result = iter.collect::>(); + assert_eq!(result, vec![false, true, true, false]) +} + +#[test] +fn rev() { + let values = &[0b01011011u8, 0b01011011u8]; + let iter = BitmapIter::new(values, 2, 13); + let result = iter.rev().collect::>(); + assert_eq!( + result, + vec![false, true, true, false, true, false, true, true, false, true, true, false, true] + .into_iter() + .rev() + .collect::>() + ) +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/mod.rs b/crates/polars/tests/it/arrow/bitmap/utils/mod.rs new file mode 100644 index 000000000000..12af43e4e949 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/mod.rs @@ -0,0 +1,83 @@ +use arrow::bitmap::utils::*; +use proptest::prelude::*; + +use super::bitmap_strategy; + +mod bit_chunks_exact; +mod chunk_iter; +mod fmt; +mod iterator; +mod slice_iterator; +mod zip_validity; + +#[test] +fn get_bit_basics() { + let input: &[u8] = &[ + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, + ]; + for i in 0..8 { + assert!(!get_bit(input, i)); + } + assert!(get_bit(input, 8)); + for i in 8 + 1..2 * 8 { + assert!(!get_bit(input, i)); + } + assert!(get_bit(input, 2 * 8 + 1)); + for i in 2 * 8 + 2..3 * 8 { + assert!(!get_bit(input, i)); + } + assert!(get_bit(input, 3 * 8 + 2)); + for i in 3 * 8 + 3..4 * 8 { + assert!(!get_bit(input, i)); + } + assert!(get_bit(input, 4 * 8 + 3)); +} + +#[test] +fn count_zeros_basics() { + let input: &[u8] = &[ + 0b01001001, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, + ]; + assert_eq!(count_zeros(input, 0, 8), 8 - 3); + assert_eq!(count_zeros(input, 1, 7), 7 - 2); + assert_eq!(count_zeros(input, 1, 8), 8 - 3); + assert_eq!(count_zeros(input, 2, 7), 7 - 3); + assert_eq!(count_zeros(input, 0, 32), 32 - 6); + assert_eq!(count_zeros(input, 9, 2), 2); + + let input: &[u8] = &[0b01000000, 0b01000001]; + assert_eq!(count_zeros(input, 8, 2), 1); + assert_eq!(count_zeros(input, 8, 3), 2); + assert_eq!(count_zeros(input, 8, 4), 3); + assert_eq!(count_zeros(input, 8, 5), 4); + assert_eq!(count_zeros(input, 8, 6), 5); + assert_eq!(count_zeros(input, 8, 7), 5); + assert_eq!(count_zeros(input, 8, 8), 6); + + let input: &[u8] = &[0b01000000, 0b01010101]; + assert_eq!(count_zeros(input, 9, 2), 1); + assert_eq!(count_zeros(input, 10, 2), 1); + assert_eq!(count_zeros(input, 11, 2), 1); + assert_eq!(count_zeros(input, 12, 2), 1); + assert_eq!(count_zeros(input, 13, 2), 1); + assert_eq!(count_zeros(input, 14, 2), 1); +} + +#[test] +fn count_zeros_1() { + // offset = 10, len = 90 => remainder + let input: &[u8] = &[73, 146, 36, 73, 146, 36, 73, 146, 36, 73, 146, 36, 9]; + assert_eq!(count_zeros(input, 10, 90), 60); +} + +proptest! { + /// Asserts that `Bitmap::null_count` equals the number of unset bits + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn null_count(bitmap in bitmap_strategy()) { + let sum_of_sets: usize = (0..bitmap.len()).map(|x| (!bitmap.get_bit(x)) as usize).sum(); + assert_eq!(bitmap.unset_bits(), sum_of_sets); + } +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/slice_iterator.rs b/crates/polars/tests/it/arrow/bitmap/utils/slice_iterator.rs new file mode 100644 index 000000000000..4a0d024643ec --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/slice_iterator.rs @@ -0,0 +1,150 @@ +use arrow::bitmap::utils::SlicesIterator; +use arrow::bitmap::Bitmap; +use proptest::prelude::*; + +use super::bitmap_strategy; + +proptest! { + /// Asserts that: + /// * `slots` is the number of set bits in the bitmap + /// * the sum of the lens of the slices equals `slots` + /// * each item on each slice is set + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn check_invariants(bitmap in bitmap_strategy()) { + let iter = SlicesIterator::new(&bitmap); + + let slots = iter.slots(); + + assert_eq!(bitmap.len() - bitmap.unset_bits(), slots); + + let slices = iter.collect::>(); + let mut sum = 0; + for (start, len) in slices { + sum += len; + for i in start..(start+len) { + assert!(bitmap.get_bit(i)); + } + } + assert_eq!(sum, slots); + } +} + +#[test] +fn single_set() { + let values = (0..16).map(|i| i == 1).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(1, 1)]); + assert_eq!(count, 1); +} + +#[test] +fn single_unset() { + let values = (0..64).map(|i| i != 1).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(0, 1), (2, 62)]); + assert_eq!(count, 64 - 1); +} + +#[test] +fn generic() { + let values = (0..130).map(|i| i % 62 != 0).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(1, 61), (63, 61), (125, 5)]); + assert_eq!(count, 61 + 61 + 5); +} + +#[test] +fn incomplete_byte() { + let values = (0..6).map(|i| i == 1).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(1, 1)]); + assert_eq!(count, 1); +} + +#[test] +fn incomplete_byte1() { + let values = (0..12).map(|i| i == 9).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(9, 1)]); + assert_eq!(count, 1); +} + +#[test] +fn end_of_byte() { + let values = (0..16).map(|i| i != 7).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(0, 7), (8, 8)]); + assert_eq!(count, 15); +} + +#[test] +fn bla() { + let values = vec![true, true, true, true, true, true, true, false] + .into_iter() + .collect::(); + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + assert_eq!(values.unset_bits() + iter.slots(), values.len()); + + let total = iter.into_iter().fold(0, |acc, x| acc + x.1); + + assert_eq!(count, total); +} + +#[test] +fn past_end_should_not_be_returned() { + let values = Bitmap::from_u8_slice([0b11111010], 3); + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + assert_eq!(values.unset_bits() + iter.slots(), values.len()); + + let total = iter.into_iter().fold(0, |acc, x| acc + x.1); + + assert_eq!(count, total); +} + +#[test] +fn sliced() { + let values = Bitmap::from_u8_slice([0b11111010, 0b11111011], 16); + let values = values.sliced(8, 2); + let iter = SlicesIterator::new(&values); + + let chunks = iter.collect::>(); + + // the first "11" in the second byte + assert_eq!(chunks, vec![(0, 2)]); +} + +#[test] +fn remainder_1() { + let values = Bitmap::from_u8_slice([0, 0, 0b00000000, 0b00010101], 27); + let values = values.sliced(22, 5); + let iter = SlicesIterator::new(&values); + let chunks = iter.collect::>(); + assert_eq!(chunks, vec![(2, 1), (4, 1)]); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/zip_validity.rs b/crates/polars/tests/it/arrow/bitmap/utils/zip_validity.rs new file mode 100644 index 000000000000..a12dedaa43d9 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/zip_validity.rs @@ -0,0 +1,106 @@ +use arrow::bitmap::utils::{BitmapIter, ZipValidity}; +use arrow::bitmap::Bitmap; + +#[test] +fn basic() { + let a = Bitmap::from([true, false]); + let a = Some(a.iter()); + let values = vec![0, 1]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!(a, vec![Some(0), None]); +} + +#[test] +fn complete() { + let a = Bitmap::from([true, false, true, false, true, false, true, false]); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!( + a, + vec![Some(0), None, Some(2), None, Some(4), None, Some(6), None] + ); +} + +#[test] +fn slices() { + let a = Bitmap::from([true, false]); + let a = Some(a.iter()); + let offsets = [0, 2, 3]; + let values = [1, 2, 3]; + let iter = offsets.windows(2).map(|x| { + let start = x[0]; + let end = x[1]; + &values[start..end] + }); + let zip = ZipValidity::new(iter, a); + + let a = zip.collect::>(); + assert_eq!(a, vec![Some([1, 2].as_ref()), None]); +} + +#[test] +fn byte() { + let a = Bitmap::from([true, false, true, false, false, true, true, false, true]); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7, 8]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!( + a, + vec![ + Some(0), + None, + Some(2), + None, + None, + Some(5), + Some(6), + None, + Some(8) + ] + ); +} + +#[test] +fn offset() { + let a = Bitmap::from([true, false, true, false, false, true, true, false, true]).sliced(1, 8); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!( + a, + vec![None, Some(1), None, None, Some(4), Some(5), None, Some(7)] + ); +} + +#[test] +fn none() { + let values = vec![0, 1, 2]; + let zip = ZipValidity::new(values.into_iter(), None::); + + let a = zip.collect::>(); + assert_eq!(a, vec![Some(0), Some(1), Some(2)]); +} + +#[test] +fn rev() { + let a = Bitmap::from([true, false, true, false, false, true, true, false, true]).sliced(1, 8); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let zip = ZipValidity::new(values.into_iter(), a); + + let result = zip.rev().collect::>(); + let expected = vec![None, Some(1), None, None, Some(4), Some(5), None, Some(7)] + .into_iter() + .rev() + .collect::>(); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/buffer/immutable.rs b/crates/polars/tests/it/arrow/buffer/immutable.rs new file mode 100644 index 000000000000..aaf16ad8fa87 --- /dev/null +++ b/crates/polars/tests/it/arrow/buffer/immutable.rs @@ -0,0 +1,119 @@ +use arrow::buffer::Buffer; + +#[test] +fn new() { + let buffer = Buffer::::new(); + assert_eq!(buffer.len(), 0); + assert!(buffer.is_empty()); +} + +#[test] +fn from_slice() { + let buffer = Buffer::::from(vec![0, 1, 2]); + assert_eq!(buffer.len(), 3); + assert_eq!(buffer.as_slice(), &[0, 1, 2]); +} + +#[test] +fn slice() { + let buffer = Buffer::::from(vec![0, 1, 2, 3]); + let buffer = buffer.sliced(1, 2); + assert_eq!(buffer.len(), 2); + assert_eq!(buffer.as_slice(), &[1, 2]); +} + +#[test] +fn from_iter() { + let buffer = (0..3).collect::>(); + assert_eq!(buffer.len(), 3); + assert_eq!(buffer.as_slice(), &[0, 1, 2]); +} + +#[test] +fn debug() { + let buffer = Buffer::::from(vec![0, 1, 2, 3]); + let buffer = buffer.sliced(1, 2); + let a = format!("{buffer:?}"); + assert_eq!(a, "[1, 2]") +} + +#[test] +fn from_vec() { + let buffer = Buffer::::from(vec![0, 1, 2]); + assert_eq!(buffer.len(), 3); + assert_eq!(buffer.as_slice(), &[0, 1, 2]); +} + +#[test] +#[cfg(feature = "arrow")] +fn from_arrow() { + let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); + let b = Buffer::::from(buffer.clone()); + assert_eq!(b.len(), 3); + assert_eq!(b.as_slice(), &[1, 2, 3]); + let back = arrow_buffer::Buffer::from(b); + assert_eq!(back, buffer); + + let buffer = buffer.slice(4); + let b = Buffer::::from(buffer.clone()); + assert_eq!(b.len(), 2); + assert_eq!(b.as_slice(), &[2, 3]); + let back = arrow_buffer::Buffer::from(b); + assert_eq!(back, buffer); + + let buffer = arrow_buffer::Buffer::from_vec(vec![1_i64, 2_i64]); + let b = Buffer::::from(buffer.clone()); + assert_eq!(b.len(), 4); + assert_eq!(b.as_slice(), &[1, 0, 2, 0]); + let back = arrow_buffer::Buffer::from(b); + assert_eq!(back, buffer); + + let buffer = buffer.slice(4); + let b = Buffer::::from(buffer.clone()); + assert_eq!(b.len(), 3); + assert_eq!(b.as_slice(), &[0, 2, 0]); + let back = arrow_buffer::Buffer::from(b); + assert_eq!(back, buffer); +} + +#[test] +#[cfg(feature = "arrow")] +fn from_arrow_vec() { + // Zero-copy vec conversion in arrow-rs + let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); + let back: Vec = buffer.into_vec().unwrap(); + + // Zero-copy vec conversion in arrow2 + let buffer = Buffer::::from(back); + let back: Vec = buffer.into_mut().unwrap_right(); + + let buffer = arrow_buffer::Buffer::from_vec(back); + let buffer = Buffer::::from(buffer); + + // But not possible after conversion between buffer representations + let _ = buffer.into_mut().unwrap_left(); + + let buffer = Buffer::::from(vec![1_i32]); + let buffer = arrow_buffer::Buffer::from(buffer); + + // But not possible after conversion between buffer representations + let _ = buffer.into_vec::().unwrap_err(); +} + +#[test] +#[cfg(feature = "arrow")] +#[should_panic(expected = "not aligned")] +fn from_arrow_misaligned() { + let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]).slice(1); + let _ = Buffer::::from(buffer); +} + +#[test] +#[cfg(feature = "arrow")] +fn from_arrow_sliced() { + let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); + let b = Buffer::::from(buffer); + let sliced = b.sliced(1, 2); + let back = arrow_buffer::Buffer::from(sliced); + assert_eq!(back.typed_data::(), &[2, 3]); +} diff --git a/crates/polars/tests/it/arrow/buffer/mod.rs b/crates/polars/tests/it/arrow/buffer/mod.rs new file mode 100644 index 000000000000..723312cd1a87 --- /dev/null +++ b/crates/polars/tests/it/arrow/buffer/mod.rs @@ -0,0 +1 @@ +mod immutable; diff --git a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs new file mode 100644 index 000000000000..3f31240b8602 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs @@ -0,0 +1,32 @@ +use arrow::array::*; +use arrow::compute::aggregate::estimated_bytes_size; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn primitive() { + let a = Int32Array::from_slice([1, 2, 3, 4, 5]); + assert_eq!(5 * std::mem::size_of::(), estimated_bytes_size(&a)); +} + +#[test] +fn boolean() { + let a = BooleanArray::from_slice([true]); + assert_eq!(1, estimated_bytes_size(&a)); +} + +#[test] +fn utf8() { + let a = Utf8Array::::from_slice(["aaa"]); + assert_eq!(3 + 2 * std::mem::size_of::(), estimated_bytes_size(&a)); +} + +#[test] +fn fixed_size_list() { + let data_type = ArrowDataType::FixedSizeList( + Box::new(Field::new("elem", ArrowDataType::Float32, false)), + 3, + ); + let values = Box::new(Float32Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); + let a = FixedSizeListArray::new(data_type, values, None); + assert_eq!(6 * std::mem::size_of::(), estimated_bytes_size(&a)); +} diff --git a/crates/polars/tests/it/arrow/compute/aggregate/mod.rs b/crates/polars/tests/it/arrow/compute/aggregate/mod.rs new file mode 100644 index 000000000000..d7de8a8c37c5 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/aggregate/mod.rs @@ -0,0 +1,2 @@ +mod memory; +mod sum; diff --git a/crates/polars/tests/it/arrow/compute/aggregate/sum.rs b/crates/polars/tests/it/arrow/compute/aggregate/sum.rs new file mode 100644 index 000000000000..011f75aad356 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/aggregate/sum.rs @@ -0,0 +1,37 @@ +use arrow::array::*; +use arrow::compute::aggregate::{sum, sum_primitive}; +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{PrimitiveScalar, Scalar}; + +#[test] +fn test_primitive_array_sum() { + let a = Int32Array::from_slice([1, 2, 3, 4, 5]); + assert_eq!( + &PrimitiveScalar::::from(Some(15)) as &dyn Scalar, + sum(&a).unwrap().as_ref() + ); + + let a = a.to(ArrowDataType::Date32); + assert_eq!( + &PrimitiveScalar::::from(Some(15)).to(ArrowDataType::Date32) as &dyn Scalar, + sum(&a).unwrap().as_ref() + ); +} + +#[test] +fn test_primitive_array_float_sum() { + let a = Float64Array::from_slice([1.1f64, 2.2, 3.3, 4.4, 5.5]); + assert!((16.5 - sum_primitive(&a).unwrap()).abs() < f64::EPSILON); +} + +#[test] +fn test_primitive_array_sum_with_nulls() { + let a = Int32Array::from(&[None, Some(2), Some(3), None, Some(5)]); + assert_eq!(10, sum_primitive(&a).unwrap()); +} + +#[test] +fn test_primitive_array_sum_all_nulls() { + let a = Int32Array::from(&[None, None, None]); + assert_eq!(None, sum_primitive(&a)); +} diff --git a/crates/polars/tests/it/arrow/compute/arity_assign.rs b/crates/polars/tests/it/arrow/compute/arity_assign.rs new file mode 100644 index 000000000000..b8ba89dda238 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/arity_assign.rs @@ -0,0 +1,21 @@ +use arrow::array::Int32Array; +use arrow::compute::arity_assign::{binary, unary}; + +#[test] +fn test_unary_assign() { + let mut a = Int32Array::from([Some(5), Some(6), None, Some(10)]); + + unary(&mut a, |x| x + 10); + + assert_eq!(a, Int32Array::from([Some(15), Some(16), None, Some(20)])) +} + +#[test] +fn test_binary_assign() { + let mut a = Int32Array::from([Some(5), Some(6), None, Some(10)]); + let b = Int32Array::from([Some(1), Some(2), Some(1), None]); + + binary(&mut a, &b, |x, y| x + y); + + assert_eq!(a, Int32Array::from([Some(6), Some(8), None, None])) +} diff --git a/crates/polars/tests/it/arrow/compute/bitwise.rs b/crates/polars/tests/it/arrow/compute/bitwise.rs new file mode 100644 index 000000000000..e2a380fbd707 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/bitwise.rs @@ -0,0 +1,41 @@ +use arrow::array::*; +use arrow::compute::bitwise::*; + +#[test] +fn test_xor() { + let a = Int32Array::from(&[Some(2), Some(4), Some(6), Some(7)]); + let b = Int32Array::from(&[None, Some(6), Some(9), Some(7)]); + let result = xor(&a, &b); + let expected = Int32Array::from(&[None, Some(2), Some(15), Some(0)]); + + assert_eq!(result, expected); +} + +#[test] +fn test_and() { + let a = Int32Array::from(&[Some(1), Some(2), Some(15)]); + let b = Int32Array::from(&[None, Some(2), Some(6)]); + let result = and(&a, &b); + let expected = Int32Array::from(&[None, Some(2), Some(6)]); + + assert_eq!(result, expected); +} + +#[test] +fn test_or() { + let a = Int32Array::from(&[Some(1), Some(2), Some(0)]); + let b = Int32Array::from(&[None, Some(2), Some(0)]); + let result = or(&a, &b); + let expected = Int32Array::from(&[None, Some(2), Some(0)]); + + assert_eq!(result, expected); +} + +#[test] +fn test_not() { + let a = Int8Array::from(&[None, Some(1i8), Some(-100i8)]); + let result = not(&a); + let expected = Int8Array::from(&[None, Some(-2), Some(99)]); + + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/compute/boolean.rs b/crates/polars/tests/it/arrow/compute/boolean.rs new file mode 100644 index 000000000000..488a53b4732d --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/boolean.rs @@ -0,0 +1,453 @@ +use std::iter::FromIterator; + +use arrow::array::*; +use arrow::compute::boolean::*; +use arrow::scalar::BooleanScalar; + +#[test] +fn array_and() { + let a = BooleanArray::from_slice(vec![false, false, true, true]); + let b = BooleanArray::from_slice(vec![false, true, false, true]); + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(c, expected); +} + +#[test] +fn array_or() { + let a = BooleanArray::from_slice(vec![false, false, true, true]); + let b = BooleanArray::from_slice(vec![false, true, false, true]); + let c = or(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, true, true, true]); + + assert_eq!(c, expected); +} + +#[test] +fn array_or_validity() { + let a = BooleanArray::from(vec![ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let b = BooleanArray::from(vec![ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = or(&a, &b); + + let expected = BooleanArray::from(vec![ + None, + None, + None, + None, + Some(false), + Some(true), + None, + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn array_not() { + let a = BooleanArray::from_slice(vec![false, true]); + let c = not(&a); + + let expected = BooleanArray::from_slice(vec![true, false]); + + assert_eq!(c, expected); +} + +#[test] +fn array_and_validity() { + let a = BooleanArray::from(vec![ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let b = BooleanArray::from(vec![ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = and(&a, &b); + + let expected = BooleanArray::from(vec![ + None, + None, + None, + None, + Some(false), + Some(false), + None, + Some(false), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn array_and_sliced_same_offset() { + let a = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, false, true, true, + ]); + let b = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, true, false, true, + ]); + + let a = a.sliced(8, 4); + let b = b.sliced(8, 4); + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_sliced_same_offset_mod8() { + let a = BooleanArray::from_slice(vec![ + false, false, true, true, false, false, false, false, false, false, false, false, + ]); + let b = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, true, false, true, + ]); + + let a = a.sliced(0, 4); + let b = b.sliced(8, 4); + + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_sliced_offset1() { + let a = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, false, true, true, + ]); + let b = BooleanArray::from_slice(vec![false, true, false, true]); + + let a = a.sliced(8, 4); + + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_sliced_offset2() { + let a = BooleanArray::from_slice(vec![false, false, true, true]); + let b = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, true, false, true, + ]); + + let b = b.sliced(8, 4); + + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_validity_offset() { + let a = BooleanArray::from(vec![None, Some(false), Some(true), None, Some(true)]); + let a = a.sliced(1, 4); + let a = a.as_any().downcast_ref::().unwrap(); + + let b = BooleanArray::from(vec![ + None, + None, + Some(true), + Some(false), + Some(true), + Some(true), + ]); + + let b = b.sliced(2, 4); + let b = b.as_any().downcast_ref::().unwrap(); + + let c = and(a, b); + + let expected = BooleanArray::from(vec![Some(false), Some(false), None, Some(true)]); + + assert_eq!(expected, c); +} + +#[test] +fn test_nonnull_array_is_null() { + let a = Int32Array::from_slice([1, 2, 3, 4]); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, false, false, false]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nonnull_array_with_offset_is_null() { + let a = Int32Array::from_slice(vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]); + let a = a.sliced(8, 4); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, false, false, false]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nonnull_array_is_not_null() { + let a = Int32Array::from_slice([1, 2, 3, 4]); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice(vec![true, true, true, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nonnull_array_with_offset_is_not_null() { + let a = Int32Array::from_slice([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]); + let a = a.sliced(8, 4); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice([true, true, true, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_is_null() { + let a = Int32Array::from(vec![Some(1), None, Some(3), None]); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, true, false, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_with_offset_is_null() { + let a = Int32Array::from(vec![ + None, + None, + None, + None, + None, + None, + None, + None, + // offset 8, previous None values are skipped by the slice + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + None, + None, + ]); + let a = a.sliced(8, 4); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, true, false, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_is_not_null() { + let a = Int32Array::from(vec![Some(1), None, Some(3), None]); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice(vec![true, false, true, false]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_with_offset_is_not_null() { + let a = Int32Array::from(vec![ + None, + None, + None, + None, + None, + None, + None, + None, + // offset 8, previous None values are skipped by the slice + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + None, + None, + ]); + let a = a.sliced(8, 4); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice(vec![true, false, true, false]); + + assert_eq!(expected, res); +} + +#[test] +fn array_and_scalar() { + let array = BooleanArray::from_slice([false, false, true, true]); + + let scalar = BooleanScalar::new(Some(true)); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([false, false, true, true]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(Some(false)); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([false, false, false, false]); + + assert_eq!(real, expected); +} + +#[test] +fn array_and_scalar_validity() { + let array = BooleanArray::from(&[None, Some(false), Some(true)]); + + let scalar = BooleanScalar::new(Some(true)); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None, Some(false), Some(true)]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(None); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); + + let array = BooleanArray::from_slice([true, false, true]); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); +} + +#[test] +fn array_or_scalar() { + let array = BooleanArray::from_slice([false, false, true, true]); + + let scalar = BooleanScalar::new(Some(true)); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([true, true, true, true]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(Some(false)); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([false, false, true, true]); + assert_eq!(real, expected); +} + +#[test] +fn array_or_scalar_validity() { + let array = BooleanArray::from(&[None, Some(false), Some(true)]); + + let scalar = BooleanScalar::new(Some(true)); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None, Some(true), Some(true)]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(None); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); + + let array = BooleanArray::from_slice([true, false, true]); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); +} + +#[test] +fn test_any_all() { + let array = BooleanArray::from(&[None, Some(false), Some(true)]); + assert!(any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[None, Some(false), Some(false)]); + assert!(!any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[None, Some(true), Some(true)]); + assert!(any(&array)); + assert!(all(&array)); + let array = BooleanArray::from_iter(std::iter::repeat(false).take(10).map(Some)); + assert!(!any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from_iter(std::iter::repeat(true).take(10).map(Some)); + assert!(any(&array)); + assert!(all(&array)); + let array = BooleanArray::from_iter([true, false, true, true].map(Some)); + assert!(any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[Some(true)]); + assert!(any(&array)); + assert!(all(&array)); + let array = BooleanArray::from(&[Some(false)]); + assert!(!any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[]); + assert!(!any(&array)); + assert!(all(&array)); +} diff --git a/crates/polars/tests/it/arrow/compute/boolean_kleene.rs b/crates/polars/tests/it/arrow/compute/boolean_kleene.rs new file mode 100644 index 000000000000..515490796d38 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/boolean_kleene.rs @@ -0,0 +1,223 @@ +use arrow::array::BooleanArray; +use arrow::compute::boolean_kleene::*; +use arrow::scalar::BooleanScalar; + +#[test] +fn and_generic() { + let lhs = BooleanArray::from(&[ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let rhs = BooleanArray::from(&[ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = and(&lhs, &rhs); + + let expected = BooleanArray::from(&[ + None, + Some(false), + None, + Some(false), + Some(false), + Some(false), + None, + Some(false), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn or_generic() { + let a = BooleanArray::from(&[ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let b = BooleanArray::from(&[ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = or(&a, &b); + + let expected = BooleanArray::from(&[ + None, + None, + Some(true), + None, + Some(false), + Some(true), + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn or_right_nulls() { + let a = BooleanArray::from_slice([false, false, false, true, true, true]); + + let b = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let c = or(&a, &b); + + let expected = BooleanArray::from(&[ + Some(true), + Some(false), + None, + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn or_left_nulls() { + let a = BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + Some(false), + None, + ]); + + let b = BooleanArray::from_slice([false, false, false, true, true, true]); + + let c = or(&a, &b); + + let expected = BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn array_and_true() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(true)); + let result = and_scalar(&array, &scalar); + + // Should be same as argument array if scalar is true. + assert_eq!(result, array); +} + +#[test] +fn array_and_false() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(false)); + let result = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[ + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + ]); + + assert_eq!(result, expected); +} + +#[test] +fn array_and_none() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(None); + let result = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None, Some(false), None, None, Some(false), None]); + + assert_eq!(result, expected); +} + +#[test] +fn array_or_true() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(true)); + let result = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[ + Some(true), + Some(true), + Some(true), + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(result, expected); +} + +#[test] +fn array_or_false() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(false)); + let result = or_scalar(&array, &scalar); + + // Should be same as argument array if scalar is false. + assert_eq!(result, array); +} + +#[test] +fn array_or_none() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(None); + let result = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[Some(true), None, None, Some(true), None, None]); + + assert_eq!(result, expected); +} + +#[test] +fn array_empty() { + let array = BooleanArray::from(&[]); + assert_eq!(any(&array), Some(false)); + assert_eq!(all(&array), Some(true)); +} diff --git a/crates/polars/tests/it/arrow/compute/if_then_else.rs b/crates/polars/tests/it/arrow/compute/if_then_else.rs new file mode 100644 index 000000000000..e203d831c39f --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/if_then_else.rs @@ -0,0 +1,42 @@ +use arrow::array::*; +use arrow::compute::if_then_else::if_then_else; +use polars_error::PolarsResult; + +#[test] +fn basics() -> PolarsResult<()> { + let lhs = Int32Array::from_slice([1, 2, 3]); + let rhs = Int32Array::from_slice([4, 5, 6]); + let predicate = BooleanArray::from_slice(vec![true, false, true]); + let c = if_then_else(&predicate, &lhs, &rhs)?; + + let expected = Int32Array::from_slice([1, 5, 3]); + + assert_eq!(expected, c.as_ref()); + Ok(()) +} + +#[test] +fn basics_nulls() -> PolarsResult<()> { + let lhs = Int32Array::from(&[Some(1), None, None]); + let rhs = Int32Array::from(&[None, Some(5), Some(6)]); + let predicate = BooleanArray::from_slice(vec![true, false, true]); + let c = if_then_else(&predicate, &lhs, &rhs)?; + + let expected = Int32Array::from(&[Some(1), Some(5), None]); + + assert_eq!(expected, c.as_ref()); + Ok(()) +} + +#[test] +fn basics_nulls_pred() -> PolarsResult<()> { + let lhs = Int32Array::from_slice([1, 2, 3]); + let rhs = Int32Array::from_slice([4, 5, 6]); + let predicate = BooleanArray::from(&[Some(true), None, Some(false)]); + let result = if_then_else(&predicate, &lhs, &rhs)?; + + let expected = Int32Array::from(&[Some(1), None, Some(6)]); + + assert_eq!(expected, result.as_ref()); + Ok(()) +} diff --git a/crates/polars/tests/it/arrow/compute/mod.rs b/crates/polars/tests/it/arrow/compute/mod.rs new file mode 100644 index 000000000000..95126a4a3a54 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/mod.rs @@ -0,0 +1,12 @@ +#[cfg(feature = "compute_aggregate")] +mod aggregate; +#[cfg(feature = "compute_bitwise")] +mod bitwise; +#[cfg(feature = "compute_boolean")] +mod boolean; +#[cfg(feature = "compute_boolean_kleene")] +mod boolean_kleene; +#[cfg(feature = "compute_if_then_else")] +mod if_then_else; + +mod arity_assign; diff --git a/crates/polars-arrow/tests/it/ffi/data.rs b/crates/polars/tests/it/arrow/ffi/data.rs similarity index 94% rename from crates/polars-arrow/tests/it/ffi/data.rs rename to crates/polars/tests/it/arrow/ffi/data.rs index 1b5fc86922c0..bb798a1bc4fc 100644 --- a/crates/polars-arrow/tests/it/ffi/data.rs +++ b/crates/polars/tests/it/arrow/ffi/data.rs @@ -1,6 +1,6 @@ -use polars_arrow::array::*; -use polars_arrow::datatypes::Field; -use polars_arrow::ffi; +use arrow::array::*; +use arrow::datatypes::Field; +use arrow::ffi; use polars_error::PolarsResult; fn _test_round_trip(array: Box, expected: Box) -> PolarsResult<()> { diff --git a/crates/polars/tests/it/arrow/ffi/mod.rs b/crates/polars/tests/it/arrow/ffi/mod.rs new file mode 100644 index 000000000000..1ca8fa75c400 --- /dev/null +++ b/crates/polars/tests/it/arrow/ffi/mod.rs @@ -0,0 +1,3 @@ +mod data; + +mod stream; diff --git a/crates/polars/tests/it/arrow/ffi/stream.rs b/crates/polars/tests/it/arrow/ffi/stream.rs new file mode 100644 index 000000000000..f949fdf4c88e --- /dev/null +++ b/crates/polars/tests/it/arrow/ffi/stream.rs @@ -0,0 +1,44 @@ +use arrow::array::*; +use arrow::datatypes::Field; +use arrow::ffi; +use polars_error::{PolarsError, PolarsResult}; + +fn _test_round_trip(arrays: Vec>) -> PolarsResult<()> { + let field = Field::new("a", arrays[0].data_type().clone(), true); + let iter = Box::new(arrays.clone().into_iter().map(Ok)) as _; + + let mut stream = Box::new(ffi::ArrowArrayStream::empty()); + + *stream = ffi::export_iterator(iter, field.clone()); + + // import + let mut stream = unsafe { ffi::ArrowArrayStreamReader::try_new(stream)? }; + + let mut produced_arrays: Vec> = vec![]; + while let Some(array) = unsafe { stream.next() } { + produced_arrays.push(array?); + } + + assert_eq!(produced_arrays, arrays); + assert_eq!(stream.field(), &field); + Ok(()) +} + +#[test] +fn round_trip() -> PolarsResult<()> { + let array = Int32Array::from(&[Some(2), None, Some(1), None]); + let array: Box = Box::new(array); + + _test_round_trip(vec![array.clone(), array.clone(), array]) +} + +#[test] +fn stream_reader_try_new_invalid_argument_error_on_released_stream() { + let released_stream = Box::new(ffi::ArrowArrayStream::empty()); + let reader = unsafe { ffi::ArrowArrayStreamReader::try_new(released_stream) }; + // poor man's assert_matches: + match reader { + Err(PolarsError::InvalidOperation(_)) => {}, + _ => panic!("ArrowArrayStreamReader::try_new did not return an InvalidArgumentError"), + } +} diff --git a/crates/polars-arrow/tests/it/io/ipc/mod.rs b/crates/polars/tests/it/arrow/io/ipc/mod.rs similarity index 89% rename from crates/polars-arrow/tests/it/io/ipc/mod.rs rename to crates/polars/tests/it/arrow/io/ipc/mod.rs index 202eaf0cdfb2..c55b346e9702 100644 --- a/crates/polars-arrow/tests/it/io/ipc/mod.rs +++ b/crates/polars/tests/it/arrow/io/ipc/mod.rs @@ -1,12 +1,12 @@ use std::io::Cursor; use std::sync::Arc; -use polars_arrow::array::*; -use polars_arrow::chunk::Chunk; -use polars_arrow::datatypes::{ArrowSchema, ArrowSchemaRef, Field}; -use polars_arrow::io::ipc::read::{read_file_metadata, FileReader}; -use polars_arrow::io::ipc::write::*; -use polars_arrow::io::ipc::IpcField; +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::datatypes::{ArrowSchema, ArrowSchemaRef, Field}; +use arrow::io::ipc::read::{read_file_metadata, FileReader}; +use arrow::io::ipc::write::*; +use arrow::io::ipc::IpcField; use polars_error::*; pub(crate) fn write( diff --git a/crates/polars-arrow/tests/it/io/mod.rs b/crates/polars/tests/it/arrow/io/mod.rs similarity index 100% rename from crates/polars-arrow/tests/it/io/mod.rs rename to crates/polars/tests/it/arrow/io/mod.rs diff --git a/crates/polars/tests/it/arrow/mod.rs b/crates/polars/tests/it/arrow/mod.rs new file mode 100644 index 000000000000..f9f3ef3d2ac9 --- /dev/null +++ b/crates/polars/tests/it/arrow/mod.rs @@ -0,0 +1,12 @@ +mod ffi; +#[cfg(feature = "io_ipc_compression")] +mod io; + +mod scalar; + +mod array; +mod bitmap; + +mod buffer; + +mod compute; diff --git a/crates/polars/tests/it/arrow/scalar/binary.rs b/crates/polars/tests/it/arrow/scalar/binary.rs new file mode 100644 index 000000000000..d1b3e984d379 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/binary.rs @@ -0,0 +1,31 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{BinaryScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = BinaryScalar::::from(Some("a")); + let b = BinaryScalar::::from(None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = BinaryScalar::::from(Some("b")); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = BinaryScalar::::from(Some("a")); + + assert_eq!(a.value(), Some(b"a".as_ref())); + assert_eq!(a.data_type(), &ArrowDataType::Binary); + assert!(a.is_valid()); + + let a = BinaryScalar::::from(None::<&str>); + + assert_eq!(a.data_type(), &ArrowDataType::LargeBinary); + assert!(!a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/boolean.rs b/crates/polars/tests/it/arrow/scalar/boolean.rs new file mode 100644 index 000000000000..7c400b0fde3e --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/boolean.rs @@ -0,0 +1,26 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{BooleanScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = BooleanScalar::from(Some(true)); + let b = BooleanScalar::from(None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = BooleanScalar::from(Some(false)); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = BooleanScalar::new(Some(true)); + + assert_eq!(a.value(), Some(true)); + assert_eq!(a.data_type(), &ArrowDataType::Boolean); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs b/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs new file mode 100644 index 000000000000..c83bc4d69749 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs @@ -0,0 +1,26 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{FixedSizeBinaryScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), Some("a")); + let b = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), Some("b")); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), Some("a")); + + assert_eq!(a.value(), Some(b"a".as_ref())); + assert_eq!(a.data_type(), &ArrowDataType::FixedSizeBinary(1)); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs b/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs new file mode 100644 index 000000000000..2aa6f45bbd74 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs @@ -0,0 +1,43 @@ +use arrow::array::BooleanArray; +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{FixedSizeListScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let dt = + ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Boolean, true)), 2); + let a = FixedSizeListScalar::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + + let b = FixedSizeListScalar::new(dt.clone(), None); + + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + + let b = FixedSizeListScalar::new(dt, Some(BooleanArray::from_slice([true, true]).boxed())); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let dt = + ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Boolean, true)), 2); + let a = FixedSizeListScalar::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + + assert_eq!( + BooleanArray::from_slice([true, false]), + a.values().unwrap().as_ref() + ); + assert_eq!(a.data_type(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/list.rs b/crates/polars/tests/it/arrow/scalar/list.rs new file mode 100644 index 000000000000..7cd2938237c9 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/list.rs @@ -0,0 +1,35 @@ +use arrow::array::BooleanArray; +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{ListScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let dt = ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Boolean, true))); + let a = ListScalar::::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + let b = ListScalar::::new(dt.clone(), None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = ListScalar::::new(dt, Some(BooleanArray::from_slice([true, true]).boxed())); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let dt = ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Boolean, true))); + let a = ListScalar::::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + + assert_eq!(BooleanArray::from_slice([true, false]), a.values().as_ref()); + assert_eq!(a.data_type(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/map.rs b/crates/polars/tests/it/arrow/scalar/map.rs new file mode 100644 index 000000000000..e9f0ede0784f --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/map.rs @@ -0,0 +1,66 @@ +use arrow::array::{BooleanArray, StructArray, Utf8Array}; +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{MapScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let kv_dt = ArrowDataType::Struct(vec![ + Field::new("key", ArrowDataType::Utf8, false), + Field::new("value", ArrowDataType::Boolean, true), + ]); + let kv_array1 = StructArray::try_new( + kv_dt.clone(), + vec![ + Utf8Array::::from([Some("k1"), Some("k2")]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + ], + None, + ) + .unwrap(); + let kv_array2 = StructArray::try_new( + kv_dt.clone(), + vec![ + Utf8Array::::from([Some("k1"), Some("k3")]).boxed(), + BooleanArray::from_slice([true, true]).boxed(), + ], + None, + ) + .unwrap(); + + let dt = ArrowDataType::Map(Box::new(Field::new("entries", kv_dt, true)), false); + let a = MapScalar::new(dt.clone(), Some(Box::new(kv_array1))); + let b = MapScalar::new(dt.clone(), None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = MapScalar::new(dt, Some(Box::new(kv_array2))); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let kv_dt = ArrowDataType::Struct(vec![ + Field::new("key", ArrowDataType::Utf8, false), + Field::new("value", ArrowDataType::Boolean, true), + ]); + let kv_array = StructArray::try_new( + kv_dt.clone(), + vec![ + Utf8Array::::from([Some("k1"), Some("k2")]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + ], + None, + ) + .unwrap(); + + let dt = ArrowDataType::Map(Box::new(Field::new("entries", kv_dt, true)), false); + let a = MapScalar::new(dt.clone(), Some(Box::new(kv_array.clone()))); + + assert_eq!(kv_array, a.values().as_ref()); + assert_eq!(a.data_type(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/mod.rs b/crates/polars/tests/it/arrow/scalar/mod.rs new file mode 100644 index 000000000000..0c1ef990b829 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/mod.rs @@ -0,0 +1,16 @@ +mod binary; +mod boolean; +mod fixed_size_binary; +mod fixed_size_list; +mod list; +mod map; +mod null; +mod primitive; +mod struct_; +mod utf8; + +// check that `PartialEq` can be derived +#[derive(PartialEq)] +struct A { + array: Box, +} diff --git a/crates/polars/tests/it/arrow/scalar/null.rs b/crates/polars/tests/it/arrow/scalar/null.rs new file mode 100644 index 000000000000..3ceaf69f83b6 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/null.rs @@ -0,0 +1,19 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{NullScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = NullScalar::new(); + assert_eq!(a, a); +} + +#[test] +fn basics() { + let a = NullScalar::default(); + + assert_eq!(a.data_type(), &ArrowDataType::Null); + assert!(!a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/primitive.rs b/crates/polars/tests/it/arrow/scalar/primitive.rs new file mode 100644 index 000000000000..954a80147833 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/primitive.rs @@ -0,0 +1,36 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{PrimitiveScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = PrimitiveScalar::from(Some(2i32)); + let b = PrimitiveScalar::::from(None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = PrimitiveScalar::::from(Some(1i32)); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = PrimitiveScalar::from(Some(2i32)); + + assert_eq!(a.value(), &Some(2i32)); + assert_eq!(a.data_type(), &ArrowDataType::Int32); + + let a = a.to(ArrowDataType::Date32); + assert_eq!(a.data_type(), &ArrowDataType::Date32); + + let a = PrimitiveScalar::::from(None); + + assert_eq!(a.data_type(), &ArrowDataType::Int32); + assert!(!a.is_valid()); + + let a = a.to(ArrowDataType::Date32); + assert_eq!(a.data_type(), &ArrowDataType::Date32); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/struct_.rs b/crates/polars/tests/it/arrow/scalar/struct_.rs new file mode 100644 index 000000000000..23461bb26568 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/struct_.rs @@ -0,0 +1,41 @@ +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{BooleanScalar, Scalar, StructScalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let dt = ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Boolean, true)]); + let a = StructScalar::new( + dt.clone(), + Some(vec![ + Box::new(BooleanScalar::from(Some(true))) as Box + ]), + ); + let b = StructScalar::new(dt.clone(), None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = StructScalar::new( + dt, + Some(vec![ + Box::new(BooleanScalar::from(Some(false))) as Box + ]), + ); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let dt = ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Boolean, true)]); + + let values = vec![Box::new(BooleanScalar::from(Some(true))) as Box]; + + let a = StructScalar::new(dt.clone(), Some(values.clone())); + + assert_eq!(a.values(), &values); + assert_eq!(a.data_type(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/utf8.rs b/crates/polars/tests/it/arrow/scalar/utf8.rs new file mode 100644 index 000000000000..bd7c6449d89c --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/utf8.rs @@ -0,0 +1,31 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{Scalar, Utf8Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = Utf8Scalar::::from(Some("a")); + let b = Utf8Scalar::::from(None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = Utf8Scalar::::from(Some("b")); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = Utf8Scalar::::from(Some("a")); + + assert_eq!(a.value(), Some("a")); + assert_eq!(a.data_type(), &ArrowDataType::Utf8); + assert!(a.is_valid()); + + let a = Utf8Scalar::::from(None::<&str>); + + assert_eq!(a.data_type(), &ArrowDataType::LargeUtf8); + assert!(!a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs index 0bd00587fbe2..212de7960562 100644 --- a/crates/polars/tests/it/core/joins.rs +++ b/crates/polars/tests/it/core/joins.rs @@ -1,6 +1,6 @@ use polars_core::utils::{accumulate_dataframes_vertical, split_df}; #[cfg(feature = "dtype-categorical")] -use polars_core::{disable_string_cache, StringCacheHolder, SINGLE_LOCK}; +use polars_core::{disable_string_cache, SINGLE_LOCK}; use super::*; diff --git a/crates/polars/tests/it/core/pivot.rs b/crates/polars/tests/it/core/pivot.rs index ea54aa02e16d..6f9c996b44cc 100644 --- a/crates/polars/tests/it/core/pivot.rs +++ b/crates/polars/tests/it/core/pivot.rs @@ -6,29 +6,47 @@ use polars_ops::pivot::{pivot, pivot_stable, PivotAgg}; #[cfg(feature = "dtype-date")] fn test_pivot_date_() -> PolarsResult<()> { let mut df = df![ - "A" => [1, 1, 1, 1, 1, 1, 1, 1], - "B" => [8, 2, 3, 6, 3, 6, 2, 2], - "C" => [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000] + "index" => [8, 2, 3, 6, 3, 6, 2, 2], + "values1" => [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000], + "values2" => [1, 1, 1, 1, 1, 1, 1, 1], ]?; - df.try_apply("C", |s| s.cast(&DataType::Date))?; + df.try_apply("values1", |s| s.cast(&DataType::Date))?; - let out = pivot(&df, ["A"], ["B"], ["C"], true, Some(PivotAgg::Count), None)?; + // Test with date as the `columns` input + let out = pivot( + &df, + ["index"], + ["values1"], + Some(["values2"]), + true, + Some(PivotAgg::Count), + None, + )?; let first = 1 as IdxSize; let expected = df![ - "B" => [8i32, 2, 3, 6], + "index" => [8i32, 2, 3, 6], "1972-09-27" => [first, 3, 2, 2] ]?; assert!(out.equals_missing(&expected)); - let mut out = pivot_stable(&df, ["C"], ["B"], ["A"], true, Some(PivotAgg::First), None)?; + // Test with date as the `values` input. + let mut out = pivot_stable( + &df, + ["index"], + ["values2"], + Some(["values1"]), + true, + Some(PivotAgg::First), + None, + )?; out.try_apply("1", |s| { let ca = s.date()?; Ok(ca.to_string("%Y-%d-%m")) })?; let expected = df![ - "B" => [8i32, 2, 3, 6], + "index" => [8i32, 2, 3, 6], "1" => ["1972-27-09", "1972-27-09", "1972-27-09", "1972-27-09"] ]?; assert!(out.equals_missing(&expected)); @@ -38,31 +56,31 @@ fn test_pivot_date_() -> PolarsResult<()> { #[test] fn test_pivot_old() { - let s0 = Series::new("foo", ["A", "A", "B", "B", "C"].as_ref()); - let s1 = Series::new("N", [1, 2, 2, 4, 2].as_ref()); - let s2 = Series::new("bar", ["k", "l", "m", "m", "l"].as_ref()); + let s0 = Series::new("index", ["A", "A", "B", "B", "C"].as_ref()); + let s2 = Series::new("columns", ["k", "l", "m", "m", "l"].as_ref()); + let s1 = Series::new("values", [1, 2, 2, 4, 2].as_ref()); let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); let pvt = pivot( &df, - ["N"], - ["foo"], - ["bar"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Sum), None, ) .unwrap(); - assert_eq!(pvt.get_column_names(), &["foo", "k", "l", "m"]); + assert_eq!(pvt.get_column_names(), &["index", "k", "l", "m"]); assert_eq!( Vec::from(&pvt.column("m").unwrap().i32().unwrap().sort(false)), &[None, None, Some(6)] ); let pvt = pivot( &df, - ["N"], - ["foo"], - ["bar"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Min), None, @@ -74,9 +92,9 @@ fn test_pivot_old() { ); let pvt = pivot( &df, - ["N"], - ["foo"], - ["bar"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Max), None, @@ -88,9 +106,9 @@ fn test_pivot_old() { ); let pvt = pivot( &df, - ["N"], - ["foo"], - ["bar"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Mean), None, @@ -102,9 +120,9 @@ fn test_pivot_old() { ); let pvt = pivot( &df, - ["N"], - ["foo"], - ["bar"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Count), None, @@ -120,46 +138,51 @@ fn test_pivot_old() { #[cfg(feature = "dtype-categorical")] fn test_pivot_categorical() -> PolarsResult<()> { let mut df = df![ - "A" => [1, 1, 1, 1, 1, 1, 1, 1], - "B" => [8, 2, 3, 6, 3, 6, 2, 2], - "C" => ["a", "b", "c", "a", "b", "c", "a", "b"] + "index" => [1, 1, 1, 1, 1, 1, 1, 1], + "columns" => ["a", "b", "c", "a", "b", "c", "a", "b"], + "values" => [8, 2, 3, 6, 3, 6, 2, 2], ]?; - df.try_apply("C", |s| { + df.try_apply("columns", |s| { s.cast(&DataType::Categorical(None, Default::default())) })?; - let out = pivot(&df, ["A"], ["B"], ["C"], true, Some(PivotAgg::Count), None)?; - assert_eq!(out.get_column_names(), &["B", "a", "b", "c"]); + let out = pivot( + &df, + ["index"], + ["columns"], + Some(["values"]), + true, + Some(PivotAgg::Count), + None, + )?; + assert_eq!(out.get_column_names(), &["index", "a", "b", "c"]); Ok(()) } #[test] fn test_pivot_new() -> PolarsResult<()> { - let df = df!["A"=> ["foo", "foo", "foo", "foo", "foo", - "bar", "bar", "bar", "bar"], - "B"=> ["one", "one", "one", "two", "two", - "one", "one", "two", "two"], - "C"=> ["small", "large", "large", "small", - "small", "large", "small", "small", "large"], - "breaky"=> ["jam", "egg", "egg", "egg", - "jam", "jam", "potato", "jam", "jam"], - "D"=> [1, 2, 2, 3, 3, 4, 5, 6, 7], - "E"=> [2, 4, 5, 5, 6, 6, 8, 9, 9] + let df = df![ + "index1"=> ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "index2"=> ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "cols1"=> ["small", "large", "large", "small", "small", "large", "small", "small", "large"], + "cols2"=> ["jam", "egg", "egg", "egg", "jam", "jam", "potato", "jam", "jam"], + "values1"=> [1, 2, 2, 3, 3, 4, 5, 6, 7], + "values2"=> [2, 4, 5, 5, 6, 6, 8, 9, 9] ]?; let out = (pivot_stable( &df, - ["D"], - ["A", "B"], - ["C"], + ["index1", "index2"], + ["cols1"], + Some(["values1"]), true, Some(PivotAgg::Sum), None, ))?; let expected = df![ - "A" => ["foo", "foo", "bar", "bar"], - "B" => ["one", "two", "one", "two"], + "index1" => ["foo", "foo", "bar", "bar"], + "index2" => ["one", "two", "one", "two"], "large" => [Some(4), None, Some(4), Some(7)], "small" => [1, 6, 5, 6], ]?; @@ -167,21 +190,21 @@ fn test_pivot_new() -> PolarsResult<()> { let out = pivot_stable( &df, - ["D"], - ["A", "B"], - ["C", "breaky"], + ["index1", "index2"], + ["cols1", "cols2"], + Some(["values1"]), true, Some(PivotAgg::Sum), None, )?; let expected = df![ - "A" => ["foo", "foo", "bar", "bar"], - "B" => ["one", "two", "one", "two"], - "large" => [Some(4), None, Some(4), Some(7)], - "small" => [1, 6, 5, 6], - "egg" => [Some(4), Some(3), None, None], - "jam" => [1, 3, 4, 13], - "potato" => [None, None, Some(5), None] + "index1" => ["foo", "foo", "bar", "bar"], + "index2" => ["one", "two", "one", "two"], + "{\"large\",\"egg\"}" => [Some(4), None, None, None], + "{\"large\",\"jam\"}" => [None, None, Some(4), Some(7)], + "{\"small\",\"egg\"}" => [None, Some(3), None, None], + "{\"small\",\"jam\"}" => [Some(1), Some(3), None, Some(6)], + "{\"small\",\"potato\"}" => [None, None, Some(5), None], ]?; assert!(out.equals_missing(&expected)); @@ -191,22 +214,22 @@ fn test_pivot_new() -> PolarsResult<()> { #[test] fn test_pivot_2() -> PolarsResult<()> { let df = df![ - "name"=> ["avg", "avg", "act", "test", "test"], - "err" => [Some("name1"), Some("name2"), None, Some("name1"), Some("name2")], - "wght"=> [0.0, 0.1, 1.0, 0.4, 0.2] + "index" => [Some("name1"), Some("name2"), None, Some("name1"), Some("name2")], + "columns"=> ["avg", "avg", "act", "test", "test"], + "values"=> [0.0, 0.1, 1.0, 0.4, 0.2] ]?; let out = pivot_stable( &df, - ["wght"], - ["err"], - ["name"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::First), None, )?; let expected = df![ - "err" => [Some("name1"), Some("name2"), None], + "index" => [Some("name1"), Some("name2"), None], "avg" => [Some(0.0), Some(0.1), None], "act" => [None, None, Some(1.)], "test" => [Some(0.4), Some(0.2), None], @@ -224,22 +247,22 @@ fn test_pivot_datetime() -> PolarsResult<()> { .and_hms_opt(12, 15, 0) .unwrap(); let df = df![ - "dt" => [dt, dt, dt, dt], - "key" => ["x", "x", "y", "y"], - "val" => [100, 50, 500, -80] + "index" => [dt, dt, dt, dt], + "columns" => ["x", "x", "y", "y"], + "values" => [100, 50, 500, -80] ]?; let out = pivot( &df, - ["val"], - ["dt"], - ["key"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Sum), None, )?; let expected = df![ - "dt" => [dt], + "index" => [dt], "x" => [150], "y" => [420] ]?; diff --git a/crates/polars/tests/it/io/avro/mod.rs b/crates/polars/tests/it/io/avro/mod.rs new file mode 100644 index 000000000000..e341bce42737 --- /dev/null +++ b/crates/polars/tests/it/io/avro/mod.rs @@ -0,0 +1,8 @@ +//! Read and write from and to Apache Avro + +mod read; +#[cfg(feature = "avro")] +mod read_async; +mod write; +#[cfg(feature = "avro")] +mod write_async; diff --git a/crates/polars/tests/it/io/avro/read.rs b/crates/polars/tests/it/io/avro/read.rs new file mode 100644 index 000000000000..d15b805b19b2 --- /dev/null +++ b/crates/polars/tests/it/io/avro/read.rs @@ -0,0 +1,356 @@ +use apache_avro::types::{Record, Value}; +use apache_avro::{Codec, Days, Duration, Millis, Months, Schema as AvroSchema, Writer}; +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::datatypes::*; +use arrow::io::avro::avro_schema::read::read_metadata; +use arrow::io::avro::read; +use polars_error::PolarsResult; + +pub(super) fn schema() -> (AvroSchema, ArrowSchema) { + let raw_schema = r#" + { + "type": "record", + "name": "test", + "fields": [ + {"name": "a", "type": "long"}, + {"name": "b", "type": "string"}, + {"name": "c", "type": "int"}, + { + "name": "date", + "type": "int", + "logicalType": "date" + }, + {"name": "d", "type": "bytes"}, + {"name": "e", "type": "double"}, + {"name": "f", "type": "boolean"}, + {"name": "g", "type": ["null", "string"], "default": null}, + {"name": "h", "type": { + "type": "array", + "items": { + "name": "item", + "type": ["null", "int"], + "default": null + } + }}, + {"name": "i", "type": { + "type": "record", + "name": "bla", + "fields": [ + {"name": "e", "type": "double"} + ] + }}, + {"name": "nullable_struct", "type": [ + "null", { + "type": "record", + "name": "foo", + "fields": [ + {"name": "e", "type": "double"} + ] + }] + , "default": null + } + ] + } +"#; + + let schema = ArrowSchema::from(vec![ + Field::new("a", ArrowDataType::Int64, false), + Field::new("b", ArrowDataType::Utf8, false), + Field::new("c", ArrowDataType::Int32, false), + Field::new("date", ArrowDataType::Date32, false), + Field::new("d", ArrowDataType::Binary, false), + Field::new("e", ArrowDataType::Float64, false), + Field::new("f", ArrowDataType::Boolean, false), + Field::new("g", ArrowDataType::Utf8, true), + Field::new( + "h", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + false, + ), + Field::new( + "i", + ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + false, + ), + Field::new( + "nullable_struct", + ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + true, + ), + ]); + + (AvroSchema::parse_str(raw_schema).unwrap(), schema) +} + +pub(super) fn data() -> Chunk> { + let data = vec![ + Some(vec![Some(1i32), None, Some(3)]), + Some(vec![Some(1i32), None, Some(3)]), + ]; + + let mut array = MutableListArray::>::new(); + array.try_extend(data).unwrap(); + + let columns = vec![ + Int64Array::from_slice([27, 47]).boxed(), + Utf8Array::::from_slice(["foo", "bar"]).boxed(), + Int32Array::from_slice([1, 1]).boxed(), + Int32Array::from_slice([1, 2]) + .to(ArrowDataType::Date32) + .boxed(), + BinaryArray::::from_slice([b"foo", b"bar"]).boxed(), + PrimitiveArray::::from_slice([1.0, 2.0]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + Utf8Array::::from([Some("foo"), None]).boxed(), + array.into_box(), + StructArray::new( + ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + vec![PrimitiveArray::::from_slice([1.0, 2.0]).boxed()], + None, + ) + .boxed(), + StructArray::new( + ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + vec![PrimitiveArray::::from_slice([1.0, 0.0]).boxed()], + Some([true, false].into()), + ) + .boxed(), + ]; + + Chunk::try_new(columns).unwrap() +} + +pub(super) fn write_avro(codec: Codec) -> Result, apache_avro::Error> { + let (avro, _) = schema(); + // a writer needs a schema and something to write to + let mut writer = Writer::with_codec(&avro, Vec::new(), codec); + + // the Record type models our Record schema + let mut record = Record::new(writer.schema()).unwrap(); + record.put("a", 27i64); + record.put("b", "foo"); + record.put("c", 1i32); + record.put("date", 1i32); + record.put("d", b"foo".as_ref()); + record.put("e", 1.0f64); + record.put("f", true); + record.put("g", Some("foo")); + record.put( + "h", + Value::Array(vec![ + Value::Union(1, Box::new(Value::Int(1))), + Value::Union(0, Box::new(Value::Null)), + Value::Union(1, Box::new(Value::Int(3))), + ]), + ); + record.put( + "i", + Value::Record(vec![("e".to_string(), Value::Double(1.0f64))]), + ); + record.put( + "duration", + Value::Duration(Duration::new(Months::new(1), Days::new(1), Millis::new(1))), + ); + record.put( + "nullable_struct", + Value::Union( + 1, + Box::new(Value::Record(vec![( + "e".to_string(), + Value::Double(1.0f64), + )])), + ), + ); + writer.append(record)?; + + let mut record = Record::new(writer.schema()).unwrap(); + record.put("b", "bar"); + record.put("a", 47i64); + record.put("c", 1i32); + record.put("date", 2i32); + record.put("d", b"bar".as_ref()); + record.put("e", 2.0f64); + record.put("f", false); + record.put("g", None::<&str>); + record.put( + "i", + Value::Record(vec![("e".to_string(), Value::Double(2.0f64))]), + ); + record.put( + "h", + Value::Array(vec![ + Value::Union(1, Box::new(Value::Int(1))), + Value::Union(0, Box::new(Value::Null)), + Value::Union(1, Box::new(Value::Int(3))), + ]), + ); + record.put("nullable_struct", Value::Union(0, Box::new(Value::Null))); + writer.append(record)?; + writer.into_inner() +} + +pub(super) fn read_avro( + mut avro: &[u8], + projection: Option>, +) -> PolarsResult<(Chunk>, ArrowSchema)> { + let file = &mut avro; + + let metadata = read_metadata(file)?; + let schema = read::infer_schema(&metadata.record)?; + + let mut reader = read::Reader::new(file, metadata, schema.fields.clone(), projection.clone()); + + let schema = if let Some(projection) = projection { + let fields = schema + .fields + .into_iter() + .zip(projection.iter()) + .filter_map(|x| if *x.1 { Some(x.0) } else { None }) + .collect::>(); + ArrowSchema::from(fields) + } else { + schema + }; + + reader.next().unwrap().map(|x| (x, schema)) +} + +fn test(codec: Codec) -> PolarsResult<()> { + let avro = write_avro(codec).unwrap(); + let expected = data(); + let (_, expected_schema) = schema(); + + let (result, schema) = read_avro(&avro, None)?; + + assert_eq!(schema, expected_schema); + assert_eq!(result, expected); + Ok(()) +} + +#[test] +fn read_without_codec() -> PolarsResult<()> { + test(Codec::Null) +} + +#[cfg(feature = "io_avro_compression")] +#[test] +fn read_deflate() -> PolarsResult<()> { + test(Codec::Deflate) +} + +#[cfg(feature = "io_avro_compression")] +#[test] +fn read_snappy() -> PolarsResult<()> { + test(Codec::Snappy) +} + +#[test] +fn test_projected() -> PolarsResult<()> { + let expected = data(); + let (_, expected_schema) = schema(); + + let avro = write_avro(Codec::Null).unwrap(); + + for i in 0..expected_schema.fields.len() { + let mut projection = vec![false; expected_schema.fields.len()]; + projection[i] = true; + + let expected = expected + .clone() + .into_arrays() + .into_iter() + .zip(projection.iter()) + .filter_map(|x| if *x.1 { Some(x.0) } else { None }) + .collect(); + let expected = Chunk::new(expected); + + let expected_fields = expected_schema + .clone() + .fields + .into_iter() + .zip(projection.iter()) + .filter_map(|x| if *x.1 { Some(x.0) } else { None }) + .collect::>(); + let expected_schema = ArrowSchema::from(expected_fields); + + let (result, schema) = read_avro(&avro, Some(projection))?; + + assert_eq!(schema, expected_schema); + assert_eq!(result, expected); + } + Ok(()) +} + +fn schema_list() -> (AvroSchema, ArrowSchema) { + let raw_schema = r#" + { + "type": "record", + "name": "test", + "fields": [ + {"name": "h", "type": { + "type": "array", + "items": { + "name": "item", + "type": "int" + } + }} + ] + } +"#; + + let schema = ArrowSchema::from(vec![Field::new( + "h", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, false))), + false, + )]); + + (AvroSchema::parse_str(raw_schema).unwrap(), schema) +} + +pub(super) fn data_list() -> Chunk> { + let data = [Some(vec![Some(1i32), Some(2), Some(3)]), Some(vec![])]; + + let mut array = MutableListArray::>::new_from( + Default::default(), + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, false))), + 0, + ); + array.try_extend(data).unwrap(); + + let columns = vec![array.into_box()]; + + Chunk::try_new(columns).unwrap() +} + +pub(super) fn write_list(codec: Codec) -> Result, apache_avro::Error> { + let (avro, _) = schema_list(); + // a writer needs a schema and something to write to + let mut writer = Writer::with_codec(&avro, Vec::new(), codec); + + // the Record type models our Record schema + let mut record = Record::new(writer.schema()).unwrap(); + record.put( + "h", + Value::Array(vec![Value::Int(1), Value::Int(2), Value::Int(3)]), + ); + writer.append(record)?; + + let mut record = Record::new(writer.schema()).unwrap(); + record.put("h", Value::Array(vec![])); + writer.append(record)?; + Ok(writer.into_inner().unwrap()) +} + +#[test] +fn test_list() -> PolarsResult<()> { + let avro = write_list(Codec::Null).unwrap(); + let expected = data_list(); + let (_, expected_schema) = schema_list(); + + let (result, schema) = read_avro(&avro, None)?; + + assert_eq!(schema, expected_schema); + assert_eq!(result, expected); + Ok(()) +} diff --git a/crates/polars/tests/it/io/avro/read_async.rs b/crates/polars/tests/it/io/avro/read_async.rs new file mode 100644 index 000000000000..d50fd7595c58 --- /dev/null +++ b/crates/polars/tests/it/io/avro/read_async.rs @@ -0,0 +1,42 @@ +use apache_avro::Codec; +use arrow::io::avro::avro_schema::read_async::{block_stream, read_metadata}; +use arrow::io::avro::read; +use futures::{pin_mut, StreamExt}; +use polars_error::PolarsResult; + +use super::read::{schema, write_avro}; + +async fn test(codec: Codec) -> PolarsResult<()> { + let avro_data = write_avro(codec).unwrap(); + let (_, expected_schema) = schema(); + + let mut reader = &mut &avro_data[..]; + + let metadata = read_metadata(&mut reader).await?; + let schema = read::infer_schema(&metadata.record)?; + + assert_eq!(schema, expected_schema); + + let blocks = block_stream(&mut reader, metadata.marker).await; + + pin_mut!(blocks); + while let Some(block) = blocks.next().await.transpose()? { + assert!(block.number_of_rows > 0 || block.data.is_empty()) + } + Ok(()) +} + +#[tokio::test] +async fn read_without_codec() -> PolarsResult<()> { + test(Codec::Null).await +} + +#[tokio::test] +async fn read_deflate() -> PolarsResult<()> { + test(Codec::Deflate).await +} + +#[tokio::test] +async fn read_snappy() -> PolarsResult<()> { + test(Codec::Snappy).await +} diff --git a/crates/polars/tests/it/io/avro/write.rs b/crates/polars/tests/it/io/avro/write.rs new file mode 100644 index 000000000000..dd8058aa52dc --- /dev/null +++ b/crates/polars/tests/it/io/avro/write.rs @@ -0,0 +1,372 @@ +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::datatypes::*; +use arrow::io::avro::avro_schema::file::{Block, CompressedBlock, Compression}; +use arrow::io::avro::avro_schema::write::{compress, write_block, write_metadata}; +use arrow::io::avro::write; +use avro_schema::schema::{Field as AvroField, Record, Schema as AvroSchema}; +use polars_error::PolarsResult; + +use super::read::read_avro; + +pub(super) fn schema() -> ArrowSchema { + ArrowSchema::from(vec![ + Field::new("int64", ArrowDataType::Int64, false), + Field::new("int64 nullable", ArrowDataType::Int64, true), + Field::new("utf8", ArrowDataType::Utf8, false), + Field::new("utf8 nullable", ArrowDataType::Utf8, true), + Field::new("int32", ArrowDataType::Int32, false), + Field::new("int32 nullable", ArrowDataType::Int32, true), + Field::new("date", ArrowDataType::Date32, false), + Field::new("date nullable", ArrowDataType::Date32, true), + Field::new("binary", ArrowDataType::Binary, false), + Field::new("binary nullable", ArrowDataType::Binary, true), + Field::new("float32", ArrowDataType::Float32, false), + Field::new("float32 nullable", ArrowDataType::Float32, true), + Field::new("float64", ArrowDataType::Float64, false), + Field::new("float64 nullable", ArrowDataType::Float64, true), + Field::new("boolean", ArrowDataType::Boolean, false), + Field::new("boolean nullable", ArrowDataType::Boolean, true), + Field::new( + "list", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + false, + ), + Field::new( + "list nullable", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + true, + ), + ]) +} + +pub(super) fn data() -> Chunk> { + let list_dt = ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))); + let list_dt1 = ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))); + + let columns = vec![ + Box::new(Int64Array::from_slice([27, 47])) as Box, + Box::new(Int64Array::from([Some(27), None])), + Box::new(Utf8Array::::from_slice(["foo", "bar"])), + Box::new(Utf8Array::::from([Some("foo"), None])), + Box::new(Int32Array::from_slice([1, 1])), + Box::new(Int32Array::from([Some(1), None])), + Box::new(Int32Array::from_slice([1, 2]).to(ArrowDataType::Date32)), + Box::new(Int32Array::from([Some(1), None]).to(ArrowDataType::Date32)), + Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), + Box::new(BinaryArray::::from([Some(b"foo"), None])), + Box::new(PrimitiveArray::::from_slice([1.0, 2.0])), + Box::new(PrimitiveArray::::from([Some(1.0), None])), + Box::new(PrimitiveArray::::from_slice([1.0, 2.0])), + Box::new(PrimitiveArray::::from([Some(1.0), None])), + Box::new(BooleanArray::from_slice([true, false])), + Box::new(BooleanArray::from([Some(true), None])), + Box::new(ListArray::::new( + list_dt, + vec![0, 2, 5].try_into().unwrap(), + Box::new(PrimitiveArray::::from([ + None, + Some(1), + None, + Some(3), + Some(4), + ])), + None, + )), + Box::new(ListArray::::new( + list_dt1, + vec![0, 2, 2].try_into().unwrap(), + Box::new(PrimitiveArray::::from([None, Some(1)])), + Some([true, false].into()), + )), + ]; + + Chunk::new(columns) +} + +pub(super) fn serialize_to_block>( + columns: &Chunk, + schema: &ArrowSchema, + compression: Option, +) -> PolarsResult { + let record = write::to_record(schema, "".to_string())?; + + let mut serializers = columns + .arrays() + .iter() + .map(|x| x.as_ref()) + .zip(record.fields.iter()) + .map(|(array, field)| write::new_serializer(array, &field.schema)) + .collect::>(); + let mut block = Block::new(columns.len(), vec![]); + + write::serialize(&mut serializers, &mut block); + + let mut compressed_block = CompressedBlock::default(); + + compress(&mut block, &mut compressed_block, compression)?; + + Ok(compressed_block) +} + +fn write_avro>( + columns: &Chunk, + schema: &ArrowSchema, + compression: Option, +) -> PolarsResult> { + let compressed_block = serialize_to_block(columns, schema, compression)?; + + let avro_fields = write::to_record(schema, "".to_string())?; + let mut file = vec![]; + + write_metadata(&mut file, avro_fields, compression)?; + + write_block(&mut file, &compressed_block)?; + + Ok(file) +} + +fn roundtrip(compression: Option) -> PolarsResult<()> { + let expected = data(); + let expected_schema = schema(); + + let data = write_avro(&expected, &expected_schema, compression)?; + + let (result, read_schema) = read_avro(&data, None)?; + + assert_eq!(expected_schema, read_schema); + for (c1, c2) in result.columns().iter().zip(expected.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + Ok(()) +} + +#[test] +fn no_compression() -> PolarsResult<()> { + roundtrip(None) +} + +#[cfg(feature = "io_avro_compression")] +#[test] +fn snappy() -> PolarsResult<()> { + roundtrip(Some(Compression::Snappy)) +} + +#[cfg(feature = "io_avro_compression")] +#[test] +fn deflate() -> PolarsResult<()> { + roundtrip(Some(Compression::Deflate)) +} + +fn large_format_schema() -> ArrowSchema { + ArrowSchema::from(vec![ + Field::new("large_utf8", ArrowDataType::LargeUtf8, false), + Field::new("large_utf8_nullable", ArrowDataType::LargeUtf8, true), + Field::new("large_binary", ArrowDataType::LargeBinary, false), + Field::new("large_binary_nullable", ArrowDataType::LargeBinary, true), + ]) +} + +fn large_format_data() -> Chunk> { + let columns = vec![ + Box::new(Utf8Array::::from_slice(["a", "b"])) as Box, + Box::new(Utf8Array::::from([Some("a"), None])), + Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), + Box::new(BinaryArray::::from([Some(b"foo"), None])), + ]; + Chunk::new(columns) +} + +fn large_format_expected_schema() -> ArrowSchema { + ArrowSchema::from(vec![ + Field::new("large_utf8", ArrowDataType::Utf8, false), + Field::new("large_utf8_nullable", ArrowDataType::Utf8, true), + Field::new("large_binary", ArrowDataType::Binary, false), + Field::new("large_binary_nullable", ArrowDataType::Binary, true), + ]) +} + +fn large_format_expected_data() -> Chunk> { + let columns = vec![ + Box::new(Utf8Array::::from_slice(["a", "b"])) as Box, + Box::new(Utf8Array::::from([Some("a"), None])), + Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), + Box::new(BinaryArray::::from([Some(b"foo"), None])), + ]; + Chunk::new(columns) +} + +#[test] +fn check_large_format() -> PolarsResult<()> { + let write_schema = large_format_schema(); + let write_data = large_format_data(); + + let data = write_avro(&write_data, &write_schema, None)?; + let (result, read_schame) = read_avro(&data, None)?; + + let expected_schema = large_format_expected_schema(); + assert_eq!(read_schame, expected_schema); + + let expected_data = large_format_expected_data(); + for (c1, c2) in result.columns().iter().zip(expected_data.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + + Ok(()) +} + +fn struct_schema() -> ArrowSchema { + ArrowSchema::from(vec![ + Field::new( + "struct", + ArrowDataType::Struct(vec![ + Field::new("item1", ArrowDataType::Int32, false), + Field::new("item2", ArrowDataType::Int32, true), + ]), + false, + ), + Field::new( + "struct nullable", + ArrowDataType::Struct(vec![ + Field::new("item1", ArrowDataType::Int32, false), + Field::new("item2", ArrowDataType::Int32, true), + ]), + true, + ), + ]) +} + +fn struct_data() -> Chunk> { + let struct_dt = ArrowDataType::Struct(vec![ + Field::new("item1", ArrowDataType::Int32, false), + Field::new("item2", ArrowDataType::Int32, true), + ]); + + Chunk::new(vec![ + Box::new(StructArray::new( + struct_dt.clone(), + vec![ + Box::new(PrimitiveArray::::from_slice([1, 2])), + Box::new(PrimitiveArray::::from([None, Some(1)])), + ], + None, + )), + Box::new(StructArray::new( + struct_dt, + vec![ + Box::new(PrimitiveArray::::from_slice([1, 2])), + Box::new(PrimitiveArray::::from([None, Some(1)])), + ], + Some([true, false].into()), + )), + ]) +} + +fn avro_record() -> Record { + Record { + name: "".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + AvroField { + name: "struct".to_string(), + doc: None, + schema: AvroSchema::Record(Record { + name: "r1".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + AvroField { + name: "item1".to_string(), + doc: None, + schema: AvroSchema::Int(None), + default: None, + order: None, + aliases: vec![], + }, + AvroField { + name: "item2".to_string(), + doc: None, + schema: AvroSchema::Union(vec![ + AvroSchema::Null, + AvroSchema::Int(None), + ]), + default: None, + order: None, + aliases: vec![], + }, + ], + }), + default: None, + order: None, + aliases: vec![], + }, + AvroField { + name: "struct nullable".to_string(), + doc: None, + schema: AvroSchema::Union(vec![ + AvroSchema::Null, + AvroSchema::Record(Record { + name: "r2".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + AvroField { + name: "item1".to_string(), + doc: None, + schema: AvroSchema::Int(None), + default: None, + order: None, + aliases: vec![], + }, + AvroField { + name: "item2".to_string(), + doc: None, + schema: AvroSchema::Union(vec![ + AvroSchema::Null, + AvroSchema::Int(None), + ]), + default: None, + order: None, + aliases: vec![], + }, + ], + }), + ]), + default: None, + order: None, + aliases: vec![], + }, + ], + } +} + +#[test] +fn avro_record_schema() -> PolarsResult<()> { + let arrow_schema = struct_schema(); + let record = write::to_record(&arrow_schema, "".to_string())?; + assert_eq!(record, avro_record()); + Ok(()) +} + +#[test] +fn struct_() -> PolarsResult<()> { + let write_schema = struct_schema(); + let write_data = struct_data(); + + let data = write_avro(&write_data, &write_schema, None)?; + let (result, read_schema) = read_avro(&data, None)?; + + let expected_schema = struct_schema(); + assert_eq!(read_schema, expected_schema); + + let expected_data = struct_data(); + for (c1, c2) in result.columns().iter().zip(expected_data.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + + Ok(()) +} diff --git a/crates/polars/tests/it/io/avro/write_async.rs b/crates/polars/tests/it/io/avro/write_async.rs new file mode 100644 index 000000000000..7c04873af64c --- /dev/null +++ b/crates/polars/tests/it/io/avro/write_async.rs @@ -0,0 +1,48 @@ +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::datatypes::*; +use arrow::io::avro::write; +use avro_schema::file::Compression; +use avro_schema::write_async::{write_block, write_metadata}; +use polars_error::PolarsResult; + +use super::read::read_avro; +use super::write::{data, schema, serialize_to_block}; + +async fn write_avro>( + columns: &Chunk, + schema: &ArrowSchema, + compression: Option, +) -> PolarsResult> { + // usually done on a different thread pool + let compressed_block = serialize_to_block(columns, schema, compression)?; + + let record = write::to_record(schema, "".to_string())?; + let mut file = vec![]; + + write_metadata(&mut file, record, compression).await?; + + write_block(&mut file, &compressed_block).await?; + + Ok(file) +} + +async fn roundtrip(compression: Option) -> PolarsResult<()> { + let expected = data(); + let expected_schema = schema(); + + let data = write_avro(&expected, &expected_schema, compression).await?; + + let (result, read_schema) = read_avro(&data, None)?; + + assert_eq!(expected_schema, read_schema); + for (c1, c2) in result.columns().iter().zip(expected.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + Ok(()) +} + +#[tokio::test] +async fn no_compression() -> PolarsResult<()> { + roundtrip(None).await +} diff --git a/crates/polars/tests/it/io/csv.rs b/crates/polars/tests/it/io/csv.rs index 69e11a66e302..c855ad45f4c0 100644 --- a/crates/polars/tests/it/io/csv.rs +++ b/crates/polars/tests/it/io/csv.rs @@ -184,8 +184,8 @@ fn test_projection() -> PolarsResult<()> { fn test_missing_data() { // missing data should not lead to parser error. let csv = r#"column_1,column_2,column_3 - 1,2,3 - 1,,3 +1,2,3 +1,,3 "#; let file = Cursor::new(csv); @@ -982,7 +982,7 @@ fn test_empty_string_cols() -> PolarsResult<()> { let s = df.column("column_1")?; let ca = s.str()?; assert_eq!( - ca.into_iter().collect::>(), + ca.iter().collect::>(), &[None, Some("abc"), None, Some("xyz")] ); @@ -1120,7 +1120,7 @@ fn test_try_parse_dates() -> PolarsResult<()> { 1742-03-21 1743-06-16 1730-07-22 -'' + 1739-03-16 "; let file = Cursor::new(csv); @@ -1131,21 +1131,6 @@ fn test_try_parse_dates() -> PolarsResult<()> { Ok(()) } -#[test] -fn test_whitespace_skipping() -> PolarsResult<()> { - let csv = "a,b - 12, 1435"; - let file = Cursor::new(csv); - let out = CsvReader::new(file).finish()?; - let expected = df![ - "a" => [12i64], - "b" => [1435i64], - ]?; - assert!(out.equals(&expected)); - - Ok(()) -} - #[test] fn test_try_parse_dates_3380() -> PolarsResult<()> { let csv = "lat;lon;validdate;t_2m:C;precip_1h:mm @@ -1171,6 +1156,6 @@ fn test_leading_whitespace_with_quote() -> PolarsResult<()> { let col_1 = df.column("ABC").unwrap(); let col_2 = df.column("DEF").unwrap(); assert_eq!(col_1.get(0)?, AnyValue::Float64(24.5)); - assert_eq!(col_2.get(0)?, AnyValue::Float64(4.1)); + assert_eq!(col_2.get(0)?, AnyValue::String(" 4.1")); Ok(()) } diff --git a/crates/polars/tests/it/io/mod.rs b/crates/polars/tests/it/io/mod.rs index 7e384e214d4c..742056fc38ac 100644 --- a/crates/polars/tests/it/io/mod.rs +++ b/crates/polars/tests/it/io/mod.rs @@ -6,6 +6,9 @@ mod json; #[cfg(feature = "parquet")] mod parquet; +#[cfg(feature = "avro")] +mod avro; + #[cfg(feature = "ipc_streaming")] mod ipc_stream; diff --git a/crates/polars/tests/it/io/parquet.rs b/crates/polars/tests/it/io/parquet.rs deleted file mode 100644 index ad5349ced151..000000000000 --- a/crates/polars/tests/it/io/parquet.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::io::Cursor; - -use polars::prelude::*; - -#[test] -fn test_vstack_empty_3220() -> PolarsResult<()> { - let df1 = df! { - "a" => ["1", "2"], - "b" => [1, 2] - }?; - let empty_df = df1.head(Some(0)); - let mut stacked = df1.clone(); - stacked.vstack_mut(&empty_df)?; - stacked.vstack_mut(&df1)?; - let mut buf = Cursor::new(Vec::new()); - ParquetWriter::new(&mut buf).finish(&mut stacked)?; - let read_df = ParquetReader::new(buf).finish()?; - assert!(stacked.equals(&read_df)); - Ok(()) -} diff --git a/crates/polars/tests/it/io/parquet/arrow/integration.rs b/crates/polars/tests/it/io/parquet/arrow/integration.rs new file mode 100644 index 000000000000..7f84c433b0d5 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/integration.rs @@ -0,0 +1,41 @@ +use arrow2::error::Result; + +use super::{integration_read, integration_write}; +use crate::io::ipc::read_gzip_json; + +fn test_file(version: &str, file_name: &str) -> Result<()> { + let (schema, _, batches) = read_gzip_json(version, file_name)?; + + // empty batches are not written/read from parquet and can be ignored + let batches = batches + .into_iter() + .filter(|x| !x.is_empty()) + .collect::>(); + + let data = integration_write(&schema, &batches)?; + + let (read_schema, read_batches) = integration_read(&data, None)?; + + assert_eq!(schema, read_schema); + assert_eq!(batches, read_batches); + + Ok(()) +} + +#[test] +fn roundtrip_100_primitive() -> Result<()> { + test_file("1.0.0-littleendian", "generated_primitive")?; + test_file("1.0.0-bigendian", "generated_primitive") +} + +#[test] +fn roundtrip_100_dict() -> Result<()> { + test_file("1.0.0-littleendian", "generated_dictionary")?; + test_file("1.0.0-bigendian", "generated_dictionary") +} + +#[test] +fn roundtrip_100_extension() -> Result<()> { + test_file("1.0.0-littleendian", "generated_extension")?; + test_file("1.0.0-bigendian", "generated_extension") +} diff --git a/crates/polars/tests/it/io/parquet/arrow/mod.rs b/crates/polars/tests/it/io/parquet/arrow/mod.rs new file mode 100644 index 000000000000..d7832a6567ea --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/mod.rs @@ -0,0 +1,1661 @@ +use std::io::{Cursor, Read, Seek}; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::chunk::Chunk; +use arrow::datatypes::*; +use arrow::legacy::prelude::LargeListArray; +use arrow::types::{i256, NativeType}; +use ethnum::AsI256; +use polars_error::PolarsResult; +use polars_parquet::read as p_read; +use polars_parquet::read::statistics::*; +use polars_parquet::write::*; + +#[cfg(feature = "io_json_integration")] +mod integration; +mod read; +mod read_indexes; +mod write; + +#[cfg(feature = "io_parquet_sample_test")] +mod sample_tests; + +type ArrayStats = (Box, Statistics); + +fn new_struct( + arrays: Vec>, + names: Vec, + validity: Option, +) -> StructArray { + let fields = names + .into_iter() + .zip(arrays.iter()) + .map(|(n, a)| Field::new(n, a.data_type().clone(), true)) + .collect(); + StructArray::new(ArrowDataType::Struct(fields), arrays, validity) +} + +pub fn read_column(mut reader: R, column: &str) -> PolarsResult { + let metadata = p_read::read_metadata(&mut reader)?; + let schema = p_read::infer_schema(&metadata)?; + + let row_group = &metadata.row_groups[0]; + + // verify that we can read indexes + if p_read::indexes::has_indexes(row_group) { + let _indexes = p_read::indexes::read_filtered_pages( + &mut reader, + row_group, + &schema.fields, + |_, _| vec![], + )?; + } + + let schema = schema.filter(|_, f| f.name == column); + + let field = &schema.fields[0]; + + let statistics = deserialize(field, row_group)?; + + let mut reader = p_read::FileReader::new(reader, metadata.row_groups, schema, None, None, None); + + let array = reader.next().unwrap()?.into_arrays().pop().unwrap(); + + Ok((array, statistics)) +} + +pub fn pyarrow_nested_edge(column: &str) -> Box { + match column { + "simple" => { + // [[0, 1]] + let data = [Some(vec![Some(0), Some(1)])]; + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "null" => { + // [None] + let data = [None::>>]; + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "empty" => { + // [None] + let data: [Option>>; 0] = []; + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "struct_list_nullable" => { + // [ + // {"f1": ["a", "b", None, "c"]} + // ] + let a = ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + ArrowDataType::Utf8View, + true, + ))), + vec![0, 4].try_into().unwrap(), + Utf8ViewArray::from_slice([Some("a"), Some("b"), None, Some("c")]).boxed(), + None, + ); + StructArray::new( + ArrowDataType::Struct(vec![Field::new("f1", a.data_type().clone(), true)]), + vec![a.boxed()], + None, + ) + .boxed() + }, + "list_struct_list_nullable" => { + let values = pyarrow_nested_edge("struct_list_nullable"); + ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + values.data_type().clone(), + true, + ))), + vec![0, 1].try_into().unwrap(), + values, + None, + ) + .boxed() + }, + _ => todo!(), + } +} + +pub fn pyarrow_nested_nullable(column: &str) -> Box { + let i64_values = &[ + Some(0), + Some(1), + Some(2), + None, + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + Some(10), + ]; + let offsets = vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(); + + let values = match column { + "list_int64" => { + // [[0, 1], None, [2, None, 3], [4, 5, 6], [], [7, 8, 9], None, [10]] + PrimitiveArray::::from(i64_values).boxed() + }, + "list_int64_required" | "list_int64_optional_required" | "list_int64_required_required" => { + // [[0, 1], None, [2, 0, 3], [4, 5, 6], [], [7, 8, 9], None, [10]] + PrimitiveArray::::from(&[ + Some(0), + Some(1), + Some(2), + Some(0), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + Some(10), + ]) + .boxed() + }, + "list_int16" => PrimitiveArray::::from(&[ + Some(0), + Some(1), + Some(2), + None, + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + Some(10), + ]) + .boxed(), + "list_bool" => BooleanArray::from(&[ + Some(false), + Some(true), + Some(true), + None, + Some(false), + Some(true), + Some(false), + Some(true), + Some(false), + Some(false), + Some(false), + Some(true), + ]) + .boxed(), + /* + string = [ + ["Hello", "bbb"], + None, + ["aa", None, ""], + ["bbb", "aa", "ccc"], + [], + ["abc", "bbb", "bbb"], + None, + [""], + ] + */ + "list_utf8" => Utf8ViewArray::from_slice([ + Some("Hello".to_string()), + Some("bbb".to_string()), + Some("aa".to_string()), + None, + Some("".to_string()), + Some("bbb".to_string()), + Some("aa".to_string()), + Some("ccc".to_string()), + Some("abc".to_string()), + Some("bbb".to_string()), + Some("bbb".to_string()), + Some("".to_string()), + ]) + .boxed(), + "list_large_binary" => Box::new(BinaryArray::::from([ + Some(b"Hello".to_vec()), + Some(b"bbb".to_vec()), + Some(b"aa".to_vec()), + None, + Some(b"".to_vec()), + Some(b"bbb".to_vec()), + Some(b"aa".to_vec()), + Some(b"ccc".to_vec()), + Some(b"abc".to_vec()), + Some(b"bbb".to_vec()), + Some(b"bbb".to_vec()), + Some(b"".to_vec()), + ])), + "list_decimal" => { + let values = i64_values + .iter() + .map(|x| x.map(|x| x as i128)) + .collect::>(); + Box::new(PrimitiveArray::::from(values).to(ArrowDataType::Decimal(9, 0))) + }, + "list_decimal256" => { + let values = i64_values + .iter() + .map(|x| x.map(|x| i256(x.as_i256()))) + .collect::>(); + let array = PrimitiveArray::::from(values).to(ArrowDataType::Decimal256(9, 0)); + Box::new(array) + }, + "list_nested_i64" + | "list_nested_inner_required_i64" + | "list_nested_inner_required_required_i64" => { + Box::new(NullArray::new(ArrowDataType::Null, 1)) + }, + "struct_list_nullable" => pyarrow_nested_nullable("list_utf8"), + "list_struct_nullable" => { + let array = Utf8ViewArray::from_slice([ + Some("a"), + Some("b"), + // + Some("b"), + None, + Some("b"), + // + None, + None, + None, + // + Some("d"), + Some("d"), + Some("d"), + // + Some("e"), + ]) + .boxed(); + new_struct( + vec![array], + vec!["a".to_string()], + Some( + [ + true, true, // + true, false, true, // + true, true, true, // + true, true, true, // + true, + ] + .into(), + ), + ) + .boxed() + }, + "list_struct_list_nullable" => { + /* + [ + [{"a": ["a"]}, {"a": ["b"]}], + None, + [{"a": ["b"]}, None, {"a": ["b"]}], + [{"a": None}, {"a": None}, {"a": None}], + [], + [{"a": ["d"]}, {"a": [None]}, {"a": ["c", "d"]}], + None, + [{"a": []}], + ] + */ + let array = Utf8ViewArray::from_slice([ + Some("a"), + Some("b"), + // + Some("b"), + Some("b"), + // + Some("d"), + None, + Some("c"), + Some("d"), + ]) + .boxed(); + + let array = ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + array.data_type().clone(), + true, + ))), + vec![0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8, 8] + .try_into() + .unwrap(), + array, + Some( + [ + true, true, true, false, true, false, false, false, true, true, true, true, + ] + .into(), + ), + ) + .boxed(); + + new_struct( + vec![array], + vec!["a".to_string()], + Some( + [ + true, true, // + true, false, true, // + true, true, true, // + true, true, true, // + true, + ] + .into(), + ), + ) + .boxed() + }, + other => unreachable!("{}", other), + }; + + match column { + "list_int64_required_required" => { + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + let data_type = + ArrowDataType::LargeList(Box::new(Field::new("item", ArrowDataType::Int64, false))); + ListArray::::new(data_type, offsets, values, None).boxed() + }, + "list_int64_optional_required" => { + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + let data_type = + ArrowDataType::LargeList(Box::new(Field::new("item", ArrowDataType::Int64, true))); + ListArray::::new(data_type, offsets, values, None).boxed() + }, + "list_nested_i64" => { + // [[0, 1]], None, [[2, None], [3]], [[4, 5], [6]], [], [[7], None, [9]], [[], [None], None], [[10]] + let data = [ + Some(vec![Some(vec![Some(0), Some(1)])]), + None, + Some(vec![Some(vec![Some(2), None]), Some(vec![Some(3)])]), + Some(vec![Some(vec![Some(4), Some(5)]), Some(vec![Some(6)])]), + Some(vec![]), + Some(vec![Some(vec![Some(7)]), None, Some(vec![Some(9)])]), + Some(vec![Some(vec![]), Some(vec![None]), None]), + Some(vec![Some(vec![Some(10)])]), + ]; + let mut a = + MutableListArray::>>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "list_nested_inner_required_i64" => { + let data = [ + Some(vec![Some(vec![Some(0), Some(1)])]), + None, + Some(vec![Some(vec![Some(2), Some(3)]), Some(vec![Some(3)])]), + Some(vec![Some(vec![Some(4), Some(5)]), Some(vec![Some(6)])]), + Some(vec![]), + Some(vec![Some(vec![Some(7)]), None, Some(vec![Some(9)])]), + None, + Some(vec![Some(vec![Some(10)])]), + ]; + let mut a = + MutableListArray::>>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "list_nested_inner_required_required_i64" => { + let data = [ + Some(vec![Some(vec![Some(0), Some(1)])]), + None, + Some(vec![Some(vec![Some(2), Some(3)]), Some(vec![Some(3)])]), + Some(vec![Some(vec![Some(4), Some(5)]), Some(vec![Some(6)])]), + Some(vec![]), + Some(vec![ + Some(vec![Some(7)]), + Some(vec![Some(8)]), + Some(vec![Some(9)]), + ]), + None, + Some(vec![Some(vec![Some(10)])]), + ]; + let mut a = + MutableListArray::>>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "struct_list_nullable" => new_struct(vec![values], vec!["a".to_string()], None).boxed(), + _ => { + let field = match column { + "list_int64" => Field::new("item", ArrowDataType::Int64, true), + "list_int64_required" => Field::new("item", ArrowDataType::Int64, false), + "list_int16" => Field::new("item", ArrowDataType::Int16, true), + "list_bool" => Field::new("item", ArrowDataType::Boolean, true), + "list_utf8" => Field::new("item", ArrowDataType::Utf8View, true), + "list_large_binary" => Field::new("item", ArrowDataType::LargeBinary, true), + "list_decimal" => Field::new("item", ArrowDataType::Decimal(9, 0), true), + "list_decimal256" => Field::new("item", ArrowDataType::Decimal256(9, 0), true), + "list_struct_nullable" => Field::new("item", values.data_type().clone(), true), + "list_struct_list_nullable" => Field::new("item", values.data_type().clone(), true), + other => unreachable!("{}", other), + }; + + let validity = Some(Bitmap::from([ + true, false, true, true, true, true, false, true, + ])); + // [0, 2, 2, 5, 8, 8, 11, 11, 12] + // [[a1, a2], None, [a3, a4, a5], [a6, a7, a8], [], [a9, a10, a11], None, [a12]] + let data_type = ArrowDataType::LargeList(Box::new(field)); + ListArray::::new(data_type, offsets, values, validity).boxed() + }, + } +} + +pub fn pyarrow_nullable(column: &str) -> Box { + let i64_values = &[ + Some(-256), + Some(-1), + None, + Some(3), + None, + Some(5), + Some(6), + Some(7), + None, + Some(9), + ]; + let u32_values = &[ + Some(0), + Some(1), + None, + Some(3), + None, + Some(5), + Some(6), + Some(7), + None, + Some(9), + ]; + + match column { + "int64" => Box::new(PrimitiveArray::::from(i64_values)), + "float64" => Box::new(PrimitiveArray::::from(&[ + Some(0.0), + Some(1.0), + None, + Some(3.0), + None, + Some(5.0), + Some(6.0), + Some(7.0), + None, + Some(9.0), + ])), + "string" => Box::new(Utf8ViewArray::from_slice([ + Some("Hello".to_string()), + None, + Some("aa".to_string()), + Some("".to_string()), + None, + Some("abc".to_string()), + None, + None, + Some("def".to_string()), + Some("aaa".to_string()), + ])), + "bool" => Box::new(BooleanArray::from([ + Some(true), + None, + Some(false), + Some(false), + None, + Some(true), + None, + None, + Some(true), + Some(true), + ])), + "timestamp_ms" => Box::new( + PrimitiveArray::::from_iter(u32_values.iter().map(|x| x.map(|x| x as i64))) + .to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None)), + ), + "uint32" => Box::new(PrimitiveArray::::from(u32_values)), + "int32_dict" => { + let keys = PrimitiveArray::::from([Some(0), Some(1), None, Some(1)]); + let values = Box::new(PrimitiveArray::::from_slice([10, 200])); + Box::new(DictionaryArray::try_from_keys(keys, values).unwrap()) + }, + "timestamp_us" => Box::new( + PrimitiveArray::::from(i64_values) + .to(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)), + ), + "timestamp_s" => Box::new( + PrimitiveArray::::from(i64_values) + .to(ArrowDataType::Timestamp(TimeUnit::Second, None)), + ), + "timestamp_s_utc" => Box::new(PrimitiveArray::::from(i64_values).to( + ArrowDataType::Timestamp(TimeUnit::Second, Some("UTC".to_string())), + )), + _ => unreachable!(), + } +} + +pub fn pyarrow_nullable_statistics(column: &str) -> Statistics { + match column { + "int64" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new(Int64Array::from_slice([-256])), + max_value: Box::new(Int64Array::from_slice([9])), + }, + "float64" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new(Float64Array::from_slice([0.0])), + max_value: Box::new(Float64Array::from_slice([9.0])), + }, + "string" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(4)]).boxed(), + min_value: Box::new(Utf8ViewArray::from_slice([Some("")])), + max_value: Box::new(Utf8ViewArray::from_slice([Some("def")])), + }, + "bool" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(4)]).boxed(), + min_value: Box::new(BooleanArray::from_slice([false])), + max_value: Box::new(BooleanArray::from_slice([true])), + }, + "timestamp_ms" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new( + Int64Array::from_slice([0]) + .to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None)), + ), + max_value: Box::new( + Int64Array::from_slice([9]) + .to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None)), + ), + }, + "uint32" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new(UInt32Array::from_slice([0])), + max_value: Box::new(UInt32Array::from_slice([9])), + }, + "int32_dict" => { + let new_dict = |array: Box| -> Box { + Box::new(DictionaryArray::try_from_keys(vec![Some(0)].into(), array).unwrap()) + }; + + Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(1)]).boxed(), + min_value: new_dict(Box::new(Int32Array::from_slice([10]))), + max_value: new_dict(Box::new(Int32Array::from_slice([200]))), + } + }, + "timestamp_us" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new( + Int64Array::from_slice([-256]) + .to(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)), + ), + max_value: Box::new( + Int64Array::from_slice([9]) + .to(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)), + ), + }, + "timestamp_s" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new( + Int64Array::from_slice([-256]).to(ArrowDataType::Timestamp(TimeUnit::Second, None)), + ), + max_value: Box::new( + Int64Array::from_slice([9]).to(ArrowDataType::Timestamp(TimeUnit::Second, None)), + ), + }, + "timestamp_s_utc" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new(Int64Array::from_slice([-256]).to(ArrowDataType::Timestamp( + TimeUnit::Second, + Some("UTC".to_string()), + ))), + max_value: Box::new(Int64Array::from_slice([9]).to(ArrowDataType::Timestamp( + TimeUnit::Second, + Some("UTC".to_string()), + ))), + }, + _ => unreachable!(), + } +} + +// these values match the values in `integration` +pub fn pyarrow_required(column: &str) -> Box { + let i64_values = &[ + Some(-256), + Some(-1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ]; + + match column { + "int64" => Box::new(PrimitiveArray::::from(i64_values)), + "bool" => Box::new(BooleanArray::from_slice([ + true, true, false, false, false, true, true, true, true, true, + ])), + "string" => Box::new(Utf8ViewArray::from_slice([ + Some("Hello"), + Some("bbb"), + Some("aa"), + Some(""), + Some("bbb"), + Some("abc"), + Some("bbb"), + Some("bbb"), + Some("def"), + Some("aaa"), + ])), + _ => unreachable!(), + } +} + +pub fn pyarrow_required_statistics(column: &str) -> Statistics { + let mut s = pyarrow_nullable_statistics(column); + s.null_count = UInt64Array::from([Some(0)]).boxed(); + s +} + +pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { + let new_list = |array: Box, nullable: bool| { + ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + array.data_type().clone(), + nullable, + ))), + vec![0, array.len() as i64].try_into().unwrap(), + array, + None, + ) + }; + + match column { + "list_int16" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(Int16Array::from_slice([0])), true).boxed(), + max_value: new_list(Box::new(Int16Array::from_slice([10])), true).boxed(), + }, + "list_bool" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(BooleanArray::from_slice([false])), true).boxed(), + max_value: new_list(Box::new(BooleanArray::from_slice([true])), true).boxed(), + }, + "list_utf8" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(Utf8ViewArray::from_slice([Some("")])), true).boxed(), + max_value: new_list(Box::new(Utf8ViewArray::from_slice([Some("ccc")])), true).boxed(), + }, + "list_large_binary" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(BinaryArray::::from_slice([b""])), true).boxed(), + max_value: new_list(Box::new(BinaryArray::::from_slice([b"ccc"])), true).boxed(), + }, + "list_decimal" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list( + Box::new(Int128Array::from_slice([0]).to(ArrowDataType::Decimal(9, 0))), + true, + ) + .boxed(), + max_value: new_list( + Box::new(Int128Array::from_slice([10]).to(ArrowDataType::Decimal(9, 0))), + true, + ) + .boxed(), + }, + "list_decimal256" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list( + Box::new( + Int256Array::from_slice([i256(0.as_i256())]) + .to(ArrowDataType::Decimal256(9, 0)), + ), + true, + ) + .boxed(), + max_value: new_list( + Box::new( + Int256Array::from_slice([i256(10.as_i256())]) + .to(ArrowDataType::Decimal256(9, 0)), + ), + true, + ) + .boxed(), + }, + "list_int64" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(Int64Array::from_slice([0])), true).boxed(), + max_value: new_list(Box::new(Int64Array::from_slice([10])), true).boxed(), + }, + "list_int64_required" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(Int64Array::from_slice([0])), false).boxed(), + max_value: new_list(Box::new(Int64Array::from_slice([10])), false).boxed(), + }, + "list_int64_required_required" | "list_int64_optional_required" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), false).boxed(), + null_count: new_list(UInt64Array::from([Some(0)]).boxed(), false).boxed(), + min_value: new_list(Box::new(Int64Array::from_slice([0])), false).boxed(), + max_value: new_list(Box::new(Int64Array::from_slice([10])), false).boxed(), + }, + "list_nested_i64" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(2)]).boxed(), true).boxed(), + min_value: new_list( + new_list(Box::new(Int64Array::from_slice([0])), true).boxed(), + true, + ) + .boxed(), + max_value: new_list( + new_list(Box::new(Int64Array::from_slice([10])), true).boxed(), + true, + ) + .boxed(), + }, + "list_nested_inner_required_required_i64" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(0)]).boxed(), + min_value: new_list( + new_list(Box::new(Int64Array::from_slice([0])), true).boxed(), + true, + ) + .boxed(), + max_value: new_list( + new_list(Box::new(Int64Array::from_slice([10])), true).boxed(), + true, + ) + .boxed(), + }, + "list_nested_inner_required_i64" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(0)]).boxed(), + min_value: new_list( + new_list(Box::new(Int64Array::from_slice([0])), true).boxed(), + true, + ) + .boxed(), + max_value: new_list( + new_list(Box::new(Int64Array::from_slice([10])), true).boxed(), + true, + ) + .boxed(), + }, + "list_struct_nullable" => Statistics { + distinct_count: new_list( + new_struct( + vec![UInt64Array::from([None]).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + null_count: new_list( + new_struct( + vec![UInt64Array::from([Some(4)]).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + min_value: new_list( + new_struct( + vec![Utf8ViewArray::from_slice([Some("a")]).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + max_value: new_list( + new_struct( + vec![Utf8ViewArray::from_slice([Some("e")]).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + }, + "list_struct_list_nullable" => Statistics { + distinct_count: new_list( + new_struct( + vec![new_list(UInt64Array::from([None]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + null_count: new_list( + new_struct( + vec![new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + min_value: new_list( + new_struct( + vec![new_list(Utf8ViewArray::from_slice([Some("a")]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + max_value: new_list( + new_struct( + vec![new_list(Utf8ViewArray::from_slice([Some("d")]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + }, + "struct_list_nullable" => Statistics { + distinct_count: new_struct( + vec![new_list(UInt64Array::from([None]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + null_count: new_struct( + vec![new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + min_value: new_struct( + vec![new_list(Utf8ViewArray::from_slice([Some("")]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + max_value: new_struct( + vec![new_list(Utf8ViewArray::from_slice([Some("ccc")]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + }, + other => todo!("{}", other), + } +} + +pub fn pyarrow_nested_edge_statistics(column: &str) -> Statistics { + let new_list = |array: Box| { + ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + array.data_type().clone(), + true, + ))), + vec![0, array.len() as i64].try_into().unwrap(), + array, + None, + ) + }; + + let new_struct = |arrays: Vec>, names: Vec| { + let fields = names + .into_iter() + .zip(arrays.iter()) + .map(|(n, a)| Field::new(n, a.data_type().clone(), true)) + .collect(); + StructArray::new(ArrowDataType::Struct(fields), arrays, None) + }; + + let names = vec!["f1".to_string()]; + + match column { + "simple" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed()).boxed(), + null_count: new_list(UInt64Array::from([Some(0)]).boxed()).boxed(), + min_value: new_list(Box::new(Int64Array::from([Some(0)]))).boxed(), + max_value: new_list(Box::new(Int64Array::from([Some(1)]))).boxed(), + }, + "null" | "empty" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed()).boxed(), + null_count: new_list(UInt64Array::from([Some(0)]).boxed()).boxed(), + min_value: new_list(Box::new(Int64Array::from([None]))).boxed(), + max_value: new_list(Box::new(Int64Array::from([None]))).boxed(), + }, + "struct_list_nullable" => Statistics { + distinct_count: new_struct( + vec![new_list(Box::new(UInt64Array::from([None]))).boxed()], + names.clone(), + ) + .boxed(), + null_count: new_struct( + vec![new_list(Box::new(UInt64Array::from([Some(1)]))).boxed()], + names.clone(), + ) + .boxed(), + min_value: Box::new(new_struct( + vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("a")]))).boxed()], + names.clone(), + )), + max_value: Box::new(new_struct( + vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("c")]))).boxed()], + names, + )), + }, + "list_struct_list_nullable" => Statistics { + distinct_count: new_list( + new_struct( + vec![new_list(Box::new(UInt64Array::from([None]))).boxed()], + names.clone(), + ) + .boxed(), + ) + .boxed(), + null_count: new_list( + new_struct( + vec![new_list(Box::new(UInt64Array::from([Some(1)]))).boxed()], + names.clone(), + ) + .boxed(), + ) + .boxed(), + min_value: new_list(Box::new(new_struct( + vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("a")]))).boxed()], + names.clone(), + ))) + .boxed(), + max_value: new_list(Box::new(new_struct( + vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("c")]))).boxed()], + names, + ))) + .boxed(), + }, + _ => unreachable!(), + } +} + +pub fn pyarrow_struct(column: &str) -> Box { + let boolean = [ + Some(true), + None, + Some(false), + Some(false), + None, + Some(true), + None, + None, + Some(true), + Some(true), + ]; + let boolean = BooleanArray::from(boolean).boxed(); + + let string = [ + Some("Hello"), + None, + Some("aa"), + Some(""), + None, + Some("abc"), + None, + None, + Some("def"), + Some("aaa"), + ]; + let string = Utf8ViewArray::from_slice(string).boxed(); + + let mask = [true, true, false, true, true, true, true, true, true, true]; + + let fields = vec![ + Field::new("f1", ArrowDataType::Utf8View, true), + Field::new("f2", ArrowDataType::Boolean, true), + ]; + match column { + "struct" => { + StructArray::new(ArrowDataType::Struct(fields), vec![string, boolean], None).boxed() + }, + "struct_nullable" => { + let values = vec![string, boolean]; + StructArray::new(ArrowDataType::Struct(fields), values, Some(mask.into())).boxed() + }, + "struct_struct" => { + let struct_ = pyarrow_struct("struct"); + Box::new(StructArray::new( + ArrowDataType::Struct(vec![ + Field::new("f1", ArrowDataType::Struct(fields), true), + Field::new("f2", ArrowDataType::Boolean, true), + ]), + vec![struct_, boolean], + None, + )) + }, + "struct_struct_nullable" => { + let struct_ = pyarrow_struct("struct"); + Box::new(StructArray::new( + ArrowDataType::Struct(vec![ + Field::new("f1", ArrowDataType::Struct(fields), true), + Field::new("f2", ArrowDataType::Boolean, true), + ]), + vec![struct_, boolean], + Some(mask.into()), + )) + }, + _ => todo!(), + } +} + +pub fn pyarrow_struct_statistics(column: &str) -> Statistics { + let new_struct = + |arrays: Vec>, names: Vec| new_struct(arrays, names, None); + + let names = vec!["f1".to_string(), "f2".to_string()]; + + match column { + "struct" | "struct_nullable" => Statistics { + distinct_count: new_struct( + vec![ + Box::new(UInt64Array::from([None])), + Box::new(UInt64Array::from([None])), + ], + names.clone(), + ) + .boxed(), + null_count: new_struct( + vec![ + Box::new(UInt64Array::from([Some(4)])), + Box::new(UInt64Array::from([Some(4)])), + ], + names.clone(), + ) + .boxed(), + min_value: Box::new(new_struct( + vec![ + Box::new(Utf8ViewArray::from_slice([Some("")])), + Box::new(BooleanArray::from_slice([false])), + ], + names.clone(), + )), + max_value: Box::new(new_struct( + vec![ + Box::new(Utf8ViewArray::from_slice([Some("def")])), + Box::new(BooleanArray::from_slice([true])), + ], + names, + )), + }, + "struct_struct" => Statistics { + distinct_count: new_struct( + vec![ + new_struct( + vec![ + Box::new(UInt64Array::from([None])), + Box::new(UInt64Array::from([None])), + ], + names.clone(), + ) + .boxed(), + UInt64Array::from([None]).boxed(), + ], + names.clone(), + ) + .boxed(), + null_count: new_struct( + vec![ + new_struct( + vec![ + Box::new(UInt64Array::from([Some(4)])), + Box::new(UInt64Array::from([Some(4)])), + ], + names.clone(), + ) + .boxed(), + UInt64Array::from([Some(4)]).boxed(), + ], + names.clone(), + ) + .boxed(), + min_value: new_struct( + vec![ + new_struct( + vec![ + Utf8ViewArray::from_slice([Some("")]).boxed(), + BooleanArray::from_slice([false]).boxed(), + ], + names.clone(), + ) + .boxed(), + BooleanArray::from_slice([false]).boxed(), + ], + names.clone(), + ) + .boxed(), + max_value: new_struct( + vec![ + new_struct( + vec![ + Utf8ViewArray::from_slice([Some("def")]).boxed(), + BooleanArray::from_slice([true]).boxed(), + ], + names.clone(), + ) + .boxed(), + BooleanArray::from_slice([true]).boxed(), + ], + names, + ) + .boxed(), + }, + "struct_struct_nullable" => Statistics { + distinct_count: new_struct( + vec![ + new_struct( + vec![ + Box::new(UInt64Array::from([None])), + Box::new(UInt64Array::from([None])), + ], + names.clone(), + ) + .boxed(), + UInt64Array::from([None]).boxed(), + ], + names.clone(), + ) + .boxed(), + null_count: new_struct( + vec![ + new_struct( + vec![ + Box::new(UInt64Array::from([Some(5)])), + Box::new(UInt64Array::from([Some(5)])), + ], + names.clone(), + ) + .boxed(), + UInt64Array::from([Some(5)]).boxed(), + ], + names.clone(), + ) + .boxed(), + min_value: new_struct( + vec![ + new_struct( + vec![ + Utf8ViewArray::from_slice([Some("")]).boxed(), + BooleanArray::from_slice([false]).boxed(), + ], + names.clone(), + ) + .boxed(), + BooleanArray::from_slice([false]).boxed(), + ], + names.clone(), + ) + .boxed(), + max_value: new_struct( + vec![ + new_struct( + vec![ + Utf8ViewArray::from_slice([Some("def")]).boxed(), + BooleanArray::from_slice([true]).boxed(), + ], + names.clone(), + ) + .boxed(), + BooleanArray::from_slice([true]).boxed(), + ], + names, + ) + .boxed(), + }, + _ => todo!(), + } +} + +fn integration_write( + schema: &ArrowSchema, + chunks: &[Chunk>], +) -> PolarsResult> { + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Uncompressed, + version: Version::V1, + data_pagesize_limit: None, + }; + + let encodings = schema + .fields + .iter() + .map(|f| { + transverse(&f.data_type, |x| { + if let ArrowDataType::Dictionary(..) = x { + Encoding::RleDictionary + } else { + Encoding::Plain + } + }) + }) + .collect(); + + let row_groups = + RowGroupIterator::try_new(chunks.iter().cloned().map(Ok), schema, options, encodings)?; + + let writer = Cursor::new(vec![]); + + let mut writer = FileWriter::try_new(writer, schema.clone(), options)?; + + for group in row_groups { + writer.write(group?)?; + } + writer.end(None)?; + + Ok(writer.into_inner().into_inner()) +} + +type IntegrationRead = (ArrowSchema, Vec>>); + +fn integration_read(data: &[u8], limit: Option) -> PolarsResult { + let mut reader = Cursor::new(data); + let metadata = p_read::read_metadata(&mut reader)?; + let schema = p_read::infer_schema(&metadata)?; + + for (field, row_group) in schema.fields.iter().zip(metadata.row_groups.iter()) { + let mut _statistics = deserialize(field, row_group)?; + } + + let reader = p_read::FileReader::new( + Cursor::new(data), + metadata.row_groups, + schema.clone(), + None, + limit, + None, + ); + + let batches = reader.collect::>>()?; + + Ok((schema, batches)) +} + +fn generic_data() -> PolarsResult<(ArrowSchema, Chunk>)> { + let array1 = PrimitiveArray::::from([Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Second)); + let array2 = Utf8ViewArray::from_slice([Some("a"), None, Some("bb")]); + + let indices = PrimitiveArray::from_values((0..3u64).map(|x| x % 2)); + let values = PrimitiveArray::from_slice([1.0f32, 3.0]).boxed(); + let array3 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let array4 = BinaryViewArray::from_slice([Some(b"ab"), Some(b"aa"), Some(b"ac")]); + + let values = PrimitiveArray::from_slice([1i16, 3]).boxed(); + let array6 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1i64, 3]) + .to(ArrowDataType::Timestamp( + TimeUnit::Millisecond, + Some("UTC".to_string()), + )) + .boxed(); + let array7 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1.0f64, 3.0]).boxed(); + let array8 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1u8, 3]).boxed(); + let array9 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1u16, 3]).boxed(); + let array10 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1u32, 3]).boxed(); + let array11 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1u64, 3]).boxed(); + let array12 = DictionaryArray::try_from_keys(indices, values).unwrap(); + + let array13 = PrimitiveArray::::from_slice([1, 2, 3]) + .to(ArrowDataType::Interval(IntervalUnit::YearMonth)); + + let schema = ArrowSchema::from(vec![ + Field::new("a1", array1.data_type().clone(), true), + Field::new("a2", array2.data_type().clone(), true), + Field::new("a3", array3.data_type().clone(), true), + Field::new("a4", array4.data_type().clone(), true), + Field::new("a6", array6.data_type().clone(), true), + Field::new("a7", array7.data_type().clone(), true), + Field::new("a8", array8.data_type().clone(), true), + Field::new("a9", array9.data_type().clone(), true), + Field::new("a10", array10.data_type().clone(), true), + Field::new("a11", array11.data_type().clone(), true), + Field::new("a12", array12.data_type().clone(), true), + Field::new("a13", array13.data_type().clone(), true), + ]); + let chunk = Chunk::try_new(vec![ + array1.boxed(), + array2.boxed(), + array3.boxed(), + array4.boxed(), + array6.boxed(), + array7.boxed(), + array8.boxed(), + array9.boxed(), + array10.boxed(), + array11.boxed(), + array12.boxed(), + array13.boxed(), + ])?; + + Ok((schema, chunk)) +} + +fn assert_roundtrip( + schema: ArrowSchema, + chunk: Chunk>, + limit: Option, +) -> PolarsResult<()> { + let r = integration_write(&schema, &[chunk.clone()])?; + + let (new_schema, new_chunks) = integration_read(&r, limit)?; + + let expected = if let Some(limit) = limit { + let expected = chunk + .into_arrays() + .into_iter() + .map(|x| x.sliced(0, limit)) + .collect::>(); + Chunk::new(expected) + } else { + chunk + }; + + assert_eq!(new_schema, schema); + assert_eq!(new_chunks, vec![expected]); + Ok(()) +} + +/// Tests that when arrow-specific types (Duration and LargeUtf8) are written to parquet, we can roundtrip its +/// logical types. +#[test] +fn arrow_type() -> PolarsResult<()> { + let (schema, chunk) = generic_data()?; + assert_roundtrip(schema, chunk, None) +} + +fn data>( + mut iter: I, + inner_is_nullable: bool, +) -> Box { + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + let data = vec![ + Some(vec![Some(iter.next().unwrap()), Some(iter.next().unwrap())]), + Some(vec![]), + Some(vec![ + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + ]), + Some(vec![ + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + ]), + Some(vec![]), + Some(vec![ + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + ]), + Some(vec![]), + Some(vec![Some(iter.next().unwrap())]), + ]; + let mut array = MutableListArray::::new_with_field( + MutablePrimitiveArray::::new(), + "item", + inner_is_nullable, + ); + array.try_extend(data).unwrap(); + array.into_box() +} + +fn assert_array_roundtrip( + is_nullable: bool, + array: Box, + limit: Option, +) -> PolarsResult<()> { + let schema = ArrowSchema::from(vec![Field::new( + "a1", + array.data_type().clone(), + is_nullable, + )]); + let chunk = Chunk::try_new(vec![array])?; + + assert_roundtrip(schema, chunk, limit) +} + +fn test_list_array_required_required(limit: Option) -> PolarsResult<()> { + assert_array_roundtrip(false, data(0..12i8, false), limit)?; + assert_array_roundtrip(false, data(0..12i16, false), limit)?; + assert_array_roundtrip(false, data(0..12i64, false), limit)?; + assert_array_roundtrip(false, data(0..12i64, false), limit)?; + assert_array_roundtrip(false, data(0..12u8, false), limit)?; + assert_array_roundtrip(false, data(0..12u16, false), limit)?; + assert_array_roundtrip(false, data(0..12u32, false), limit)?; + assert_array_roundtrip(false, data(0..12u64, false), limit)?; + assert_array_roundtrip(false, data((0..12).map(|x| (x as f32) * 1.0), false), limit)?; + assert_array_roundtrip( + false, + data((0..12).map(|x| (x as f64) * 1.0f64), false), + limit, + ) +} + +#[test] +fn list_array_required_required() -> PolarsResult<()> { + test_list_array_required_required(None) +} + +#[test] +fn list_array_optional_optional() -> PolarsResult<()> { + assert_array_roundtrip(true, data(0..12, true), None) +} + +#[test] +fn list_array_required_optional() -> PolarsResult<()> { + assert_array_roundtrip(true, data(0..12, false), None) +} + +#[test] +fn list_array_optional_required() -> PolarsResult<()> { + assert_array_roundtrip(false, data(0..12, true), None) +} + +#[test] +fn list_slice() -> PolarsResult<()> { + let data = vec![ + Some(vec![None, Some(2)]), + Some(vec![Some(3), Some(4)]), + Some(vec![Some(5), Some(6)]), + ]; + let mut array = MutableListArray::::new_with_field( + MutablePrimitiveArray::::new(), + "item", + true, + ); + array.try_extend(data).unwrap(); + let a: ListArray = array.into(); + let a = a.sliced(2, 1); + assert_array_roundtrip(false, a.boxed(), None) +} + +#[test] +fn struct_slice() -> PolarsResult<()> { + let a = pyarrow_nested_nullable("struct_list_nullable"); + + let a = a.sliced(2, 1); + assert_array_roundtrip(true, a, None) +} + +#[test] +fn list_struct_slice() -> PolarsResult<()> { + let a = pyarrow_nested_nullable("list_struct_nullable"); + + let a = a.sliced(2, 1); + assert_array_roundtrip(true, a, None) +} + +#[test] +fn list_int_nullable() -> PolarsResult<()> { + let data = vec![ + Some(vec![Some(1)]), + None, + Some(vec![None, Some(2)]), + Some(vec![]), + Some(vec![Some(3)]), + None, + ]; + let mut array = MutableListArray::::new_with_field( + MutablePrimitiveArray::::new(), + "item", + true, + ); + array.try_extend(data).unwrap(); + assert_array_roundtrip(true, array.into_box(), None) +} + +#[test] +fn limit() -> PolarsResult<()> { + let (schema, chunk) = generic_data()?; + assert_roundtrip(schema, chunk, Some(2)) +} + +#[test] +fn limit_list() -> PolarsResult<()> { + test_list_array_required_required(Some(2)) +} + +fn nested_dict_data( + data_type: ArrowDataType, +) -> PolarsResult<(ArrowSchema, Chunk>)> { + let values = match data_type { + ArrowDataType::Float32 => PrimitiveArray::from_slice([1.0f32, 3.0]).boxed(), + ArrowDataType::Utf8View => Utf8ViewArray::from_slice([Some("a"), Some("b")]).boxed(), + _ => unreachable!(), + }; + + let indices = PrimitiveArray::from_values((0..3u64).map(|x| x % 2)); + let values = DictionaryArray::try_from_keys(indices, values).unwrap(); + let values = LargeListArray::try_new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + values.data_type().clone(), + false, + ))), + vec![0i64, 0, 0, 2, 3].try_into().unwrap(), + values.boxed(), + Some([true, false, true, true].into()), + )?; + + let schema = ArrowSchema::from(vec![Field::new("c1", values.data_type().clone(), true)]); + let chunk = Chunk::try_new(vec![values.boxed()])?; + + Ok((schema, chunk)) +} + +#[test] +fn nested_dict() -> PolarsResult<()> { + let (schema, chunk) = nested_dict_data(ArrowDataType::Float32)?; + + assert_roundtrip(schema, chunk, None) +} + +#[test] +fn nested_dict_utf8() -> PolarsResult<()> { + let (schema, chunk) = nested_dict_data(ArrowDataType::Utf8View)?; + + assert_roundtrip(schema, chunk, None) +} + +#[test] +fn nested_dict_limit() -> PolarsResult<()> { + let (schema, chunk) = nested_dict_data(ArrowDataType::Float32)?; + + assert_roundtrip(schema, chunk, Some(2)) +} + +#[test] +fn filter_chunk() -> PolarsResult<()> { + let chunk1 = Chunk::new(vec![PrimitiveArray::from_slice([1i16, 3]).boxed()]); + let chunk2 = Chunk::new(vec![PrimitiveArray::from_slice([2i16, 4]).boxed()]); + let schema = ArrowSchema::from(vec![Field::new("c1", ArrowDataType::Int16, true)]); + + let r = integration_write(&schema, &[chunk1.clone(), chunk2.clone()])?; + + let mut reader = Cursor::new(r); + + let metadata = p_read::read_metadata(&mut reader)?; + + let new_schema = p_read::infer_schema(&metadata)?; + assert_eq!(new_schema, schema); + + // select chunk 1 + let row_groups = metadata + .row_groups + .into_iter() + .enumerate() + .filter(|(index, _)| *index == 0) + .map(|(_, row_group)| row_group) + .collect(); + + let reader = p_read::FileReader::new(reader, row_groups, schema, None, None, None); + + let new_chunks = reader.collect::>>()?; + + assert_eq!(new_chunks, vec![chunk1]); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/arrow/read.rs b/crates/polars/tests/it/io/parquet/arrow/read.rs new file mode 100644 index 000000000000..fa6502557d81 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/read.rs @@ -0,0 +1,159 @@ +use std::path::PathBuf; + +use polars_parquet::arrow::read::*; + +use super::*; +#[cfg(feature = "parquet")] +#[test] +fn all_types() -> PolarsResult<()> { + let dir = env!("CARGO_MANIFEST_DIR"); + let path = PathBuf::from(dir).join("../../docs/data/alltypes_plain.parquet"); + + let mut reader = std::fs::File::open(path)?; + + let metadata = read_metadata(&mut reader)?; + let schema = infer_schema(&metadata)?; + let reader = FileReader::new(reader, metadata.row_groups, schema, None, None, None); + + let batches = reader.collect::>>()?; + assert_eq!(batches.len(), 1); + + let result = batches[0].columns()[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result, &Int32Array::from_slice([4, 5, 6, 7, 2, 3, 0, 1])); + + let result = batches[0].columns()[6] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result, + &Float32Array::from_slice([0.0, 1.1, 0.0, 1.1, 0.0, 1.1, 0.0, 1.1]) + ); + + let result = batches[0].columns()[9] + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!( + result, + &BinaryArray::::from_slice([[48], [49], [48], [49], [48], [49], [48], [49]]) + ); + + Ok(()) +} + +#[cfg(feature = "parquet")] +#[test] +fn all_types_chunked() -> PolarsResult<()> { + // this has one batch with 8 elements + let dir = env!("CARGO_MANIFEST_DIR"); + let path = PathBuf::from(dir).join("../../docs/data/alltypes_plain.parquet"); + let mut reader = std::fs::File::open(path)?; + + let metadata = read_metadata(&mut reader)?; + let schema = infer_schema(&metadata)?; + // chunk it in 5 (so, (5,3)) + let reader = FileReader::new(reader, metadata.row_groups, schema, Some(5), None, None); + + let batches = reader.collect::>>()?; + assert_eq!(batches.len(), 2); + + assert_eq!(batches[0].len(), 5); + assert_eq!(batches[1].len(), 3); + + let result = batches[0].columns()[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result, &Int32Array::from_slice([4, 5, 6, 7, 2])); + + let result = batches[1].columns()[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result, &Int32Array::from_slice([3, 0, 1])); + + let result = batches[0].columns()[6] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result, &Float32Array::from_slice([0.0, 1.1, 0.0, 1.1, 0.0])); + + let result = batches[0].columns()[9] + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!( + result, + &BinaryArray::::from_slice([[48], [49], [48], [49], [48]]) + ); + + let result = batches[1].columns()[9] + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(result, &BinaryArray::::from_slice([[49], [48], [49]])); + + Ok(()) +} + +#[test] +fn read_int96_timestamps() -> PolarsResult<()> { + use std::collections::BTreeMap; + + let timestamp_data = &[ + 0x50, 0x41, 0x52, 0x31, 0x15, 0x04, 0x15, 0x48, 0x15, 0x3c, 0x4c, 0x15, 0x06, 0x15, 0x00, + 0x12, 0x00, 0x00, 0x24, 0x00, 0x00, 0x0d, 0x01, 0x08, 0x9f, 0xd5, 0x1f, 0x0d, 0x0a, 0x44, + 0x00, 0x00, 0x59, 0x68, 0x25, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14, + 0xfb, 0x2a, 0x00, 0x15, 0x00, 0x15, 0x14, 0x15, 0x18, 0x2c, 0x15, 0x06, 0x15, 0x10, 0x15, + 0x06, 0x15, 0x06, 0x1c, 0x00, 0x00, 0x00, 0x0a, 0x24, 0x02, 0x00, 0x00, 0x00, 0x06, 0x01, + 0x02, 0x03, 0x24, 0x00, 0x26, 0x9e, 0x01, 0x1c, 0x15, 0x06, 0x19, 0x35, 0x10, 0x00, 0x06, + 0x19, 0x18, 0x0a, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x73, 0x15, 0x02, + 0x16, 0x06, 0x16, 0x9e, 0x01, 0x16, 0x96, 0x01, 0x26, 0x60, 0x26, 0x08, 0x29, 0x2c, 0x15, + 0x04, 0x15, 0x00, 0x15, 0x02, 0x00, 0x15, 0x00, 0x15, 0x10, 0x15, 0x02, 0x00, 0x00, 0x00, + 0x15, 0x04, 0x19, 0x2c, 0x35, 0x00, 0x18, 0x06, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x15, + 0x02, 0x00, 0x15, 0x06, 0x25, 0x02, 0x18, 0x0a, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, + 0x6d, 0x70, 0x73, 0x00, 0x16, 0x06, 0x19, 0x1c, 0x19, 0x1c, 0x26, 0x9e, 0x01, 0x1c, 0x15, + 0x06, 0x19, 0x35, 0x10, 0x00, 0x06, 0x19, 0x18, 0x0a, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, + 0x61, 0x6d, 0x70, 0x73, 0x15, 0x02, 0x16, 0x06, 0x16, 0x9e, 0x01, 0x16, 0x96, 0x01, 0x26, + 0x60, 0x26, 0x08, 0x29, 0x2c, 0x15, 0x04, 0x15, 0x00, 0x15, 0x02, 0x00, 0x15, 0x00, 0x15, + 0x10, 0x15, 0x02, 0x00, 0x00, 0x00, 0x16, 0x9e, 0x01, 0x16, 0x06, 0x26, 0x08, 0x16, 0x96, + 0x01, 0x14, 0x00, 0x00, 0x28, 0x20, 0x70, 0x61, 0x72, 0x71, 0x75, 0x65, 0x74, 0x2d, 0x63, + 0x70, 0x70, 0x2d, 0x61, 0x72, 0x72, 0x6f, 0x77, 0x20, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x20, 0x31, 0x32, 0x2e, 0x30, 0x2e, 0x30, 0x19, 0x1c, 0x1c, 0x00, 0x00, 0x00, 0x95, + 0x00, 0x00, 0x00, 0x50, 0x41, 0x52, 0x31, + ]; + + let parse = |time_unit: TimeUnit| { + let mut reader = Cursor::new(timestamp_data); + let metadata = read_metadata(&mut reader)?; + let schema = arrow::datatypes::ArrowSchema { + fields: vec![arrow::datatypes::Field::new( + "timestamps", + arrow::datatypes::ArrowDataType::Timestamp(time_unit, None), + false, + )], + metadata: BTreeMap::new(), + }; + let reader = FileReader::new(reader, metadata.row_groups, schema, Some(5), None, None); + reader.collect::>>() + }; + + // This data contains int96 timestamps in the year 1000 and 3000, which are out of range for + // Timestamp(TimeUnit::Nanoseconds) and will cause a panic in dev builds/overflow in release builds + // However, the code should work for the Microsecond/Millisecond time units + for time_unit in [ + arrow::datatypes::TimeUnit::Microsecond, + arrow::datatypes::TimeUnit::Millisecond, + arrow::datatypes::TimeUnit::Second, + ] { + parse(time_unit).expect("Should not error"); + } + std::panic::catch_unwind(|| parse(arrow::datatypes::TimeUnit::Nanosecond)) + .expect_err("Should be a panic error"); + + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/arrow/read_indexes.rs b/crates/polars/tests/it/io/parquet/arrow/read_indexes.rs new file mode 100644 index 000000000000..ec16f2c9a363 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/read_indexes.rs @@ -0,0 +1,257 @@ +use std::io::Cursor; + +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::datatypes::*; +use polars_error::{PolarsError, PolarsResult}; +use polars_parquet::read::*; +use polars_parquet::write::*; + +/// Returns 2 sets of pages with different the same number of rows distributed un-evenly +fn pages( + arrays: &[&dyn Array], + encoding: Encoding, +) -> PolarsResult<(Vec, Vec, ArrowSchema)> { + // create pages with different number of rows + let array11 = PrimitiveArray::::from_slice([1, 2, 3, 4]); + let array12 = PrimitiveArray::::from_slice([5]); + let array13 = PrimitiveArray::::from_slice([6]); + + let schema = ArrowSchema::from(vec![ + Field::new("a1", ArrowDataType::Int64, false), + Field::new( + "a2", + arrays[0].data_type().clone(), + arrays.iter().map(|x| x.null_count()).sum::() != 0usize, + ), + ]); + + let parquet_schema = to_parquet_schema(&schema)?; + + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Uncompressed, + version: Version::V1, + data_pagesize_limit: None, + }; + + let pages1 = [array11, array12, array13] + .into_iter() + .map(|array| { + array_to_page( + &array, + parquet_schema.columns()[0] + .descriptor + .primitive_type + .clone(), + &[Nested::Primitive(None, true, array.len())], + options, + Encoding::Plain, + ) + }) + .collect::>>()?; + + let pages2 = arrays + .iter() + .flat_map(|array| { + array_to_pages( + *array, + parquet_schema.columns()[1] + .descriptor + .primitive_type + .clone(), + &[Nested::Primitive(None, true, array.len())], + options, + encoding, + ) + .unwrap() + .collect::>>() + .unwrap() + }) + .collect::>(); + + Ok((pages1, pages2, schema)) +} + +/// Tests reading pages while skipping indexes +fn read_with_indexes( + (pages1, pages2, schema): (Vec, Vec, ArrowSchema), + expected: Box, +) -> PolarsResult<()> { + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Uncompressed, + version: Version::V1, + data_pagesize_limit: None, + }; + + let to_compressed = |pages: Vec| { + let encoded_pages = DynIter::new(pages.into_iter().map(Ok)); + let compressed_pages = + Compressor::new(encoded_pages, options.compression, vec![]).map_err(PolarsError::from); + PolarsResult::Ok(DynStreamingIterator::new(compressed_pages)) + }; + + let row_group = DynIter::new(vec![to_compressed(pages1), to_compressed(pages2)].into_iter()); + + let writer = vec![]; + let mut writer = FileWriter::try_new(writer, schema, options)?; + + writer.write(row_group)?; + writer.end(None)?; + let data = writer.into_inner(); + + let mut reader = Cursor::new(data); + + let metadata = read_metadata(&mut reader)?; + + let schema = infer_schema(&metadata)?; + + // row group-based filtering can be done here + let row_groups = metadata.row_groups; + + // one per row group + let pages = row_groups + .iter() + .map(|row_group| { + assert!(indexes::has_indexes(row_group)); + + indexes::read_filtered_pages(&mut reader, row_group, &schema.fields, |_, intervals| { + let first_field = &intervals[0]; + let first_field_column = &first_field[0]; + assert_eq!(first_field_column.len(), 3); + let selection = [false, true, false]; + + first_field_column + .iter() + .zip(selection) + .filter(|(_i, is_selected)| *is_selected) + .map(|(i, _is_selected)| *i) + .collect() + }) + }) + .collect::>>()?; + + // apply projection pushdown + let schema = schema.filter(|index, _| index == 1); + let pages = pages + .into_iter() + .map(|pages| { + pages + .into_iter() + .enumerate() + .filter(|(index, _)| *index == 1) + .map(|(_, pages)| pages) + .collect::>() + }) + .collect::>(); + + let expected = Chunk::new(vec![expected]); + + let chunks = FileReader::new( + reader, + row_groups, + schema, + Some(1024 * 8 * 8), + None, + Some(pages), + ); + + let arrays = chunks.collect::>>()?; + + assert_eq!(arrays, vec![expected]); + Ok(()) +} + +#[test] +fn indexed_required_i64() -> PolarsResult<()> { + let array21 = Int32Array::from_slice([1, 2, 3]); + let array22 = Int32Array::from_slice([4, 5, 6]); + let expected = Int32Array::from_slice([5]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_optional_i64() -> PolarsResult<()> { + let array21 = Int32Array::from([Some(1), Some(2), None]); + let array22 = Int32Array::from([None, Some(5), Some(6)]); + let expected = Int32Array::from_slice([5]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_optional_i64_delta() -> PolarsResult<()> { + let array21 = Int32Array::from([Some(1), Some(2), None]); + let array22 = Int32Array::from([None, Some(5), Some(6)]); + let expected = Int32Array::from_slice([5]).boxed(); + + read_with_indexes( + pages(&[&array21, &array22], Encoding::DeltaBinaryPacked)?, + expected, + ) +} + +#[test] +fn indexed_required_i64_delta() -> PolarsResult<()> { + let array21 = Int32Array::from_slice([1, 2, 3]); + let array22 = Int32Array::from_slice([4, 5, 6]); + let expected = Int32Array::from_slice([5]).boxed(); + + read_with_indexes( + pages(&[&array21, &array22], Encoding::DeltaBinaryPacked)?, + expected, + ) +} + +#[test] +fn indexed_required_fixed_len() -> PolarsResult<()> { + let array21 = FixedSizeBinaryArray::from_slice([[127], [128], [129]]); + let array22 = FixedSizeBinaryArray::from_slice([[130], [131], [132]]); + let expected = FixedSizeBinaryArray::from_slice([[131]]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_optional_fixed_len() -> PolarsResult<()> { + let array21 = FixedSizeBinaryArray::from([Some([127]), Some([128]), None]); + let array22 = FixedSizeBinaryArray::from([None, Some([131]), Some([132])]); + let expected = FixedSizeBinaryArray::from_slice([[131]]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_required_boolean() -> PolarsResult<()> { + let array21 = BooleanArray::from_slice([true, false, true]); + let array22 = BooleanArray::from_slice([false, false, true]); + let expected = BooleanArray::from_slice([false]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_optional_boolean() -> PolarsResult<()> { + let array21 = BooleanArray::from([Some(true), Some(false), None]); + let array22 = BooleanArray::from([None, Some(false), Some(true)]); + let expected = BooleanArray::from_slice([false]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_dict() -> PolarsResult<()> { + let indices = PrimitiveArray::from_values((0..6u64).map(|x| x % 2)); + let values = PrimitiveArray::from_slice([4i64, 6i64]).boxed(); + let array = DictionaryArray::try_from_keys(indices, values).unwrap(); + + let indices = PrimitiveArray::from_slice([0u64]); + let values = PrimitiveArray::from_slice([4i64, 6i64]).boxed(); + let expected = DictionaryArray::try_from_keys(indices, values).unwrap(); + + let expected = expected.boxed(); + + read_with_indexes(pages(&[&array], Encoding::RleDictionary)?, expected) +} diff --git a/crates/polars/tests/it/io/parquet/arrow/sample_tests.rs b/crates/polars/tests/it/io/parquet/arrow/sample_tests.rs new file mode 100644 index 000000000000..a577ee0efe7b --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/sample_tests.rs @@ -0,0 +1,115 @@ +use std::borrow::Borrow; +use std::io::Cursor; + +use arrow2::chunk::Chunk; +use arrow2::datatypes::{Field, Metadata, Schema}; +use arrow2::error::Result; +use arrow2::io::parquet::read as p_read; +use arrow2::io::parquet::write::*; +use sample_arrow2::array::ArbitraryArray; +use sample_arrow2::chunk::{ArbitraryChunk, ChainedChunk}; +use sample_arrow2::datatypes::{sample_flat, ArbitraryArrowDataType}; +use sample_std::{Chance, Random, Regex, Sample}; +use sample_test::sample_test; + +fn deep_chunk(depth: usize, len: usize) -> ArbitraryChunk { + let names = Regex::new("[a-z]{4,8}"); + let data_type = ArbitraryArrowDataType { + struct_branch: 1..3, + names: names.clone(), + // TODO: this breaks the test + // nullable: Chance(0.5), + nullable: Chance(0.0), + flat: sample_flat, + } + .sample_depth(depth); + + let array = ArbitraryArray { + names, + branch: 0..10, + len: len..(len + 1), + null: Chance(0.1), + // TODO: this breaks the test + // is_nullable: true, + is_nullable: false, + }; + + ArbitraryChunk { + // TODO: shrinking appears to be an issue with chunks this large. issues + // currently reproduce on the smaller sizes anyway. + // chunk_len: 10..1000, + chunk_len: 1..10, + array_count: 1..2, + data_type, + array, + } +} + +#[sample_test] +fn round_trip_sample( + #[sample(deep_chunk(5, 100).sample_one())] chained: ChainedChunk, +) -> Result<()> { + sample_test::env_logger_init(); + let chunks = vec![chained.value]; + let name = Regex::new("[a-z]{4, 8}"); + let mut g = Random::new(); + + // TODO: this probably belongs in a helper in sample-arrow2 + let schema = Schema { + fields: chunks + .first() + .unwrap() + .iter() + .map(|arr| { + Field::new( + name.generate(&mut g), + arr.data_type().clone(), + arr.validity().is_some(), + ) + }) + .collect(), + metadata: Metadata::default(), + }; + + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Uncompressed, + version: Version::V2, + data_pagesize_limit: None, + }; + + let encodings: Vec<_> = schema + .borrow() + .fields + .iter() + .map(|field| transverse(field.data_type(), |_| Encoding::Plain)) + .collect(); + + let row_groups = RowGroupIterator::try_new( + chunks.clone().into_iter().map(Ok), + &schema, + options, + encodings, + )?; + + let buffer = Cursor::new(vec![]); + let mut writer = FileWriter::try_new(buffer, schema, options)?; + + for group in row_groups { + writer.write(group?)?; + } + writer.end(None)?; + + let mut buffer = writer.into_inner(); + + let metadata = p_read::read_metadata(&mut buffer)?; + let schema = p_read::infer_schema(&metadata)?; + + let mut reader = p_read::FileReader::new(buffer, metadata.row_groups, schema, None, None, None); + + let result: Vec<_> = reader.collect::>()?; + + assert_eq!(result, chunks); + + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/arrow/write.rs b/crates/polars/tests/it/io/parquet/arrow/write.rs new file mode 100644 index 000000000000..2f1f35d9d456 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/write.rs @@ -0,0 +1,491 @@ +use polars_parquet::arrow::write::*; + +use super::*; + +fn round_trip( + column: &str, + file: &str, + version: Version, + compression: CompressionOptions, + encodings: Vec, +) -> PolarsResult<()> { + round_trip_opt_stats(column, file, version, compression, encodings, true) +} + +fn round_trip_opt_stats( + column: &str, + file: &str, + version: Version, + compression: CompressionOptions, + encodings: Vec, + check_stats: bool, +) -> PolarsResult<()> { + let (array, statistics) = match file { + "nested" => ( + pyarrow_nested_nullable(column), + pyarrow_nested_nullable_statistics(column), + ), + "nullable" => ( + pyarrow_nullable(column), + pyarrow_nullable_statistics(column), + ), + "required" => ( + pyarrow_required(column), + pyarrow_required_statistics(column), + ), + "struct" => (pyarrow_struct(column), pyarrow_struct_statistics(column)), + "nested_edge" => ( + pyarrow_nested_edge(column), + pyarrow_nested_edge_statistics(column), + ), + _ => unreachable!(), + }; + + let field = Field::new("a1", array.data_type().clone(), true); + let schema = ArrowSchema::from(vec![field]); + + let options = WriteOptions { + write_statistics: true, + compression, + version, + data_pagesize_limit: None, + }; + + let iter = vec![Chunk::try_new(vec![array.clone()])]; + + let row_groups = + RowGroupIterator::try_new(iter.into_iter(), &schema, options, vec![encodings])?; + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::try_new(writer, schema, options)?; + + for group in row_groups { + writer.write(group?)?; + } + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + + let (result, stats) = read_column(&mut Cursor::new(data), "a1")?; + + assert_eq!(array.as_ref(), result.as_ref()); + if check_stats { + assert_eq!(statistics, stats); + } + Ok(()) +} + +#[test] +fn int64_optional_v1() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn int64_required_v1() -> PolarsResult<()> { + round_trip( + "int64", + "required", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn int64_optional_v2() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn int64_optional_delta() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaBinaryPacked], + ) +} + +#[test] +fn int64_required_delta() -> PolarsResult<()> { + round_trip( + "int64", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaBinaryPacked], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn int64_optional_v2_compressed() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_optional_v1() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_required_v1() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_optional_v2() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_required_v2() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn utf8_optional_v2_compressed() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn utf8_required_v2_compressed() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_optional_v1() -> PolarsResult<()> { + round_trip( + "bool", + "nullable", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_required_v1() -> PolarsResult<()> { + round_trip( + "bool", + "required", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_optional_v2_uncompressed() -> PolarsResult<()> { + round_trip( + "bool", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_required_v2_uncompressed() -> PolarsResult<()> { + round_trip( + "bool", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn bool_required_v2_compressed() -> PolarsResult<()> { + round_trip( + "bool", + "required", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_optional_v2() -> PolarsResult<()> { + round_trip( + "list_int64", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_optional_v1() -> PolarsResult<()> { + round_trip( + "list_int64", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_required_required_v1() -> PolarsResult<()> { + round_trip( + "list_int64_required_required", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_required_required_v2() -> PolarsResult<()> { + round_trip( + "list_int64_required_required", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_bool_optional_v2() -> PolarsResult<()> { + round_trip( + "list_bool", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_bool_optional_v1() -> PolarsResult<()> { + round_trip( + "list_bool", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_utf8_optional_v2() -> PolarsResult<()> { + round_trip( + "list_utf8", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_utf8_optional_v1() -> PolarsResult<()> { + round_trip( + "list_utf8", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_nested_inner_required_required_i64() -> PolarsResult<()> { + round_trip_opt_stats( + "list_nested_inner_required_required_i64", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + false, + ) +} + +#[test] +fn v1_nested_struct_list_nullable() -> PolarsResult<()> { + round_trip_opt_stats( + "struct_list_nullable", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + true, + ) +} + +#[test] +fn v1_nested_list_struct_list_nullable() -> PolarsResult<()> { + round_trip_opt_stats( + "list_struct_list_nullable", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + true, + ) +} + +#[test] +fn utf8_optional_v2_delta() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaLengthByteArray], + ) +} + +#[test] +fn utf8_required_v2_delta() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaLengthByteArray], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn i64_optional_v2_dict_compressed() -> PolarsResult<()> { + round_trip( + "int32_dict", + "nullable", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::RleDictionary], + ) +} + +#[test] +fn struct_v1() -> PolarsResult<()> { + round_trip( + "struct", + "struct", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain, Encoding::Plain], + ) +} + +#[test] +fn struct_v2() -> PolarsResult<()> { + round_trip( + "struct", + "struct", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain, Encoding::Plain], + ) +} + +#[test] +fn nested_edge_simple() -> PolarsResult<()> { + round_trip( + "simple", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn nested_edge_null() -> PolarsResult<()> { + round_trip( + "null", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn v1_nested_edge_struct_list_nullable() -> PolarsResult<()> { + round_trip( + "struct_list_nullable", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn nested_edge_list_struct_list_nullable() -> PolarsResult<()> { + round_trip( + "list_struct_list_nullable", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} diff --git a/crates/polars/tests/it/io/parquet/mod.rs b/crates/polars/tests/it/io/parquet/mod.rs new file mode 100644 index 000000000000..ba6bbe5dc724 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/mod.rs @@ -0,0 +1,207 @@ +#![forbid(unsafe_code)] +mod arrow; +mod read; +mod roundtrip; +mod write; + +use std::io::Cursor; +use std::path::PathBuf; + +use polars::prelude::*; + +// The dynamic representation of values in native Rust. This is not exhaustive. +// todo: maybe refactor this into serde/json? +#[derive(Debug, PartialEq)] +pub enum Array { + Int32(Vec>), + Int64(Vec>), + Int96(Vec>), + Float(Vec>), + Double(Vec>), + Boolean(Vec>), + Binary(Vec>>), + FixedLenBinary(Vec>>), + List(Vec>), + Struct(Vec, Vec), +} + +use std::sync::Arc; + +use polars_parquet::parquet::schema::types::{PhysicalType, PrimitiveType}; +use polars_parquet::parquet::statistics::*; + +pub fn alltypes_plain(column: &str) -> Array { + match column { + "id" => { + let expected = vec![4, 5, 6, 7, 2, 3, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "id-short-array" => { + let expected = vec![4]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "bool_col" => { + let expected = vec![true, false, true, false, true, false, true, false]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Boolean(expected) + }, + "tinyint_col" => { + let expected = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "smallint_col" => { + let expected = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "int_col" => { + let expected = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "bigint_col" => { + let expected = vec![0, 10, 0, 10, 0, 10, 0, 10]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int64(expected) + }, + "float_col" => { + let expected = vec![0.0, 1.1, 0.0, 1.1, 0.0, 1.1, 0.0, 1.1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Float(expected) + }, + "double_col" => { + let expected = vec![0.0, 10.1, 0.0, 10.1, 0.0, 10.1, 0.0, 10.1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Double(expected) + }, + "date_string_col" => { + let expected = vec![ + vec![48, 51, 47, 48, 49, 47, 48, 57], + vec![48, 51, 47, 48, 49, 47, 48, 57], + vec![48, 52, 47, 48, 49, 47, 48, 57], + vec![48, 52, 47, 48, 49, 47, 48, 57], + vec![48, 50, 47, 48, 49, 47, 48, 57], + vec![48, 50, 47, 48, 49, 47, 48, 57], + vec![48, 49, 47, 48, 49, 47, 48, 57], + vec![48, 49, 47, 48, 49, 47, 48, 57], + ]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Binary(expected) + }, + "string_col" => { + let expected = vec![ + vec![48], + vec![49], + vec![48], + vec![49], + vec![48], + vec![49], + vec![48], + vec![49], + ]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Binary(expected) + }, + "timestamp_col" => { + todo!() + }, + _ => unreachable!(), + } +} + +pub fn alltypes_statistics(column: &str) -> Arc { + match column { + "id" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int32), + null_count: Some(0), + distinct_count: None, + min_value: Some(0), + max_value: Some(7), + }), + "id-short-array" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int32), + null_count: Some(0), + distinct_count: None, + min_value: Some(4), + max_value: Some(4), + }), + "bool_col" => Arc::new(BooleanStatistics { + null_count: Some(0), + distinct_count: None, + min_value: Some(false), + max_value: Some(true), + }), + "tinyint_col" | "smallint_col" | "int_col" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int32), + null_count: Some(0), + distinct_count: None, + min_value: Some(0), + max_value: Some(1), + }), + "bigint_col" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int64), + null_count: Some(0), + distinct_count: None, + min_value: Some(0), + max_value: Some(10), + }), + "float_col" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Float), + null_count: Some(0), + distinct_count: None, + min_value: Some(0.0), + max_value: Some(1.1), + }), + "double_col" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Double), + null_count: Some(0), + distinct_count: None, + min_value: Some(0.0), + max_value: Some(10.1), + }), + "date_string_col" => Arc::new(BinaryStatistics { + primitive_type: PrimitiveType::from_physical( + "col".to_string(), + PhysicalType::ByteArray, + ), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![48, 49, 47, 48, 49, 47, 48, 57]), + max_value: Some(vec![48, 52, 47, 48, 49, 47, 48, 57]), + }), + "string_col" => Arc::new(BinaryStatistics { + primitive_type: PrimitiveType::from_physical( + "col".to_string(), + PhysicalType::ByteArray, + ), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![48]), + max_value: Some(vec![49]), + }), + "timestamp_col" => { + todo!() + }, + _ => unreachable!(), + } +} + +#[test] +fn test_vstack_empty_3220() -> PolarsResult<()> { + let df1 = df! { + "a" => ["1", "2"], + "b" => [1, 2] + }?; + let empty_df = df1.head(Some(0)); + let mut stacked = df1.clone(); + stacked.vstack_mut(&empty_df)?; + stacked.vstack_mut(&df1)?; + let mut buf = Cursor::new(Vec::new()); + ParquetWriter::new(&mut buf).finish(&mut stacked)?; + let read_df = ParquetReader::new(buf).finish()?; + assert!(stacked.equals(&read_df)); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/read/binary.rs b/crates/polars/tests/it/io/parquet/read/binary.rs new file mode 100644 index 000000000000..b86e28ddf0ff --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/binary.rs @@ -0,0 +1,33 @@ +use polars_parquet::parquet::deserialize::FixedLenBinaryPageState; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::page::DataPage; + +use super::dictionary::BinaryPageDict; +use super::utils::deserialize_optional; + +pub fn page_to_vec(page: &DataPage, dict: Option<&BinaryPageDict>) -> Result>>> { + assert_eq!(page.descriptor.max_rep_level, 0); + + let state = FixedLenBinaryPageState::try_new(page, dict)?; + + match state { + FixedLenBinaryPageState::Optional(validity, values) => { + deserialize_optional(validity, values.map(|x| Ok(x.to_vec()))) + }, + FixedLenBinaryPageState::Required(values) => values + .map(|x| Ok(x.to_vec())) + .map(Some) + .map(|x| x.transpose()) + .collect(), + FixedLenBinaryPageState::RequiredDictionary(dict) => dict + .indexes + .map(|x| x.and_then(|x| dict.dict.value(x as usize).map(|x| x.to_vec()).map(Some))) + .collect(), + FixedLenBinaryPageState::OptionalDictionary(validity, dict) => { + let values = dict + .indexes + .map(|x| x.and_then(|x| dict.dict.value(x as usize).map(|x| x.to_vec()))); + deserialize_optional(validity, values) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/boolean.rs b/crates/polars/tests/it/io/parquet/read/boolean.rs new file mode 100644 index 000000000000..7642f4023fff --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/boolean.rs @@ -0,0 +1,20 @@ +use polars_parquet::parquet::deserialize::BooleanPageState; +use polars_parquet::parquet::encoding::hybrid_rle::BitmapIter; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::page::DataPage; + +use super::utils::deserialize_optional; + +pub fn page_to_vec(page: &DataPage) -> Result>> { + assert_eq!(page.descriptor.max_rep_level, 0); + let state = BooleanPageState::try_new(page)?; + + match state { + BooleanPageState::Optional(validity, mut values) => { + deserialize_optional(validity, values.by_ref().map(Ok)) + }, + BooleanPageState::Required(bitmap, length) => { + Ok(BitmapIter::new(bitmap, 0, length).map(Some).collect()) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/deserialize.rs b/crates/polars/tests/it/io/parquet/read/deserialize.rs new file mode 100644 index 000000000000..1b5cf18b1452 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/deserialize.rs @@ -0,0 +1,314 @@ +use polars_parquet::parquet::deserialize::{ + FilteredHybridBitmapIter, FilteredHybridEncoded, HybridEncoded, +}; +use polars_parquet::parquet::indexes::Interval; + +#[test] +fn bitmap_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![Ok(HybridEncoded::Bitmap(&[0b01000011], 7))].into_iter(), + vec![Interval::new(1, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Bitmap { + values: &[0b01000011], + offset: 1, + length: 2, + } + ] + ); +} + +#[test] +fn bitmap_complete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![Ok(HybridEncoded::Bitmap(&[0b01000011], 8))].into_iter(), + vec![Interval::new(0, 8)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![FilteredHybridEncoded::Bitmap { + values: &[0b01000011], + offset: 0, + length: 8, + }] + ); +} + +#[test] +fn bitmap_interval_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Bitmap(&[0b01000011], 8)), + Ok(HybridEncoded::Bitmap(&[0b11111111], 8)), + ] + .into_iter(), + vec![Interval::new(0, 10)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Bitmap { + values: &[0b01000011], + offset: 0, + length: 8, + }, + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 0, + length: 2, + } + ] + ); +} + +#[test] +fn bitmap_interval_run_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Bitmap(&[0b01100011], 8)), + Ok(HybridEncoded::Bitmap(&[0b11111111], 8)), + ] + .into_iter(), + vec![Interval::new(0, 5), Interval::new(7, 4)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Bitmap { + values: &[0b01100011], + offset: 0, + length: 5, + }, + FilteredHybridEncoded::Skipped(2), + FilteredHybridEncoded::Bitmap { + values: &[0b01100011], + offset: 7, + length: 1, + }, + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 0, + length: 3, + } + ] + ); +} + +#[test] +fn bitmap_interval_run_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Bitmap(&[0b01100011], 8)), + Ok(HybridEncoded::Bitmap(&[0b11111111], 8)), + ] + .into_iter(), + vec![Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(4), + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 1, + length: 2, + }, + ] + ); +} + +#[test] +fn bitmap_interval_run_offset_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Bitmap(&[0b01100011], 8)), + Ok(HybridEncoded::Bitmap(&[0b11111111], 8)), + ] + .into_iter(), + vec![Interval::new(0, 1), Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Bitmap { + values: &[0b01100011], + offset: 0, + length: 1, + }, + FilteredHybridEncoded::Skipped(3), + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 1, + length: 2, + }, + ] + ); +} + +#[test] +fn repeated_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![Ok(HybridEncoded::Repeated(true, 7))].into_iter(), + vec![Interval::new(1, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Repeated { + is_set: true, + length: 2, + } + ] + ); +} + +#[test] +fn repeated_complete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![Ok(HybridEncoded::Repeated(true, 8))].into_iter(), + vec![Interval::new(0, 8)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![FilteredHybridEncoded::Repeated { + is_set: true, + length: 8, + }] + ); +} + +#[test] +fn repeated_interval_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Repeated(true, 8)), + Ok(HybridEncoded::Repeated(false, 8)), + ] + .into_iter(), + vec![Interval::new(0, 10)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Repeated { + is_set: true, + length: 8, + }, + FilteredHybridEncoded::Repeated { + is_set: false, + length: 2, + } + ] + ); +} + +#[test] +fn repeated_interval_run_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Repeated(true, 8)), + Ok(HybridEncoded::Repeated(false, 8)), + ] + .into_iter(), + vec![Interval::new(0, 5), Interval::new(7, 4)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Repeated { + is_set: true, + length: 5, + }, + FilteredHybridEncoded::Skipped(2), + FilteredHybridEncoded::Repeated { + is_set: true, + length: 1, + }, + FilteredHybridEncoded::Repeated { + is_set: false, + length: 3, + } + ] + ); +} + +#[test] +fn repeated_interval_run_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Repeated(true, 8)), + Ok(HybridEncoded::Repeated(false, 8)), + ] + .into_iter(), + vec![Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(8), + FilteredHybridEncoded::Skipped(0), + FilteredHybridEncoded::Repeated { + is_set: false, + length: 2, + }, + ] + ); +} + +#[test] +fn repeated_interval_run_offset_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Repeated(true, 8)), + Ok(HybridEncoded::Repeated(false, 8)), + ] + .into_iter(), + vec![Interval::new(0, 1), Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Repeated { + is_set: true, + length: 1, + }, + FilteredHybridEncoded::Skipped(7), + FilteredHybridEncoded::Skipped(0), + FilteredHybridEncoded::Repeated { + is_set: false, + length: 2, + }, + ] + ); +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/binary.rs b/crates/polars/tests/it/io/parquet/read/dictionary/binary.rs new file mode 100644 index 000000000000..8b9bce7c50e7 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/binary.rs @@ -0,0 +1,48 @@ +use polars_parquet::parquet::encoding::get_length; +use polars_parquet::parquet::error::Error; + +#[derive(Debug)] +pub struct BinaryPageDict { + values: Vec>, +} + +impl BinaryPageDict { + pub fn new(values: Vec>) -> Self { + Self { values } + } + + #[inline] + pub fn value(&self, index: usize) -> Result<&[u8], Error> { + self.values + .get(index) + .map(|x| x.as_ref()) + .ok_or_else(|| Error::OutOfSpec("invalid index".to_string())) + } +} + +fn read_plain(bytes: &[u8], length: usize) -> Result>, Error> { + let mut bytes = bytes; + let mut values = Vec::new(); + + for _ in 0..length { + let slot_length = get_length(bytes).unwrap(); + bytes = &bytes[4..]; + + if slot_length > bytes.len() { + return Err(Error::OutOfSpec( + "The string on a dictionary page has a length that is out of bounds".to_string(), + )); + } + let (result, remaining) = bytes.split_at(slot_length); + + values.push(result.to_vec()); + bytes = remaining; + } + + Ok(values) +} + +pub fn read(buf: &[u8], num_values: usize) -> Result { + let values = read_plain(buf, num_values)?; + Ok(BinaryPageDict::new(values)) +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/fixed_len_binary.rs b/crates/polars/tests/it/io/parquet/read/dictionary/fixed_len_binary.rs new file mode 100644 index 000000000000..31b150fcb820 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/fixed_len_binary.rs @@ -0,0 +1,31 @@ +use polars_parquet::parquet::error::{Error, Result}; + +#[derive(Debug)] +pub struct FixedLenByteArrayPageDict { + values: Vec, + size: usize, +} + +impl FixedLenByteArrayPageDict { + pub fn new(values: Vec, size: usize) -> Self { + Self { values, size } + } + + #[inline] + pub fn value(&self, index: usize) -> Result<&[u8]> { + self.values + .get(index * self.size..(index + 1) * self.size) + .ok_or_else(|| { + Error::OutOfSpec( + "The data page has an index larger than the dictionary page values".to_string(), + ) + }) + } +} + +pub fn read(buf: &[u8], size: usize, num_values: usize) -> Result { + let length = size.saturating_mul(num_values); + let values = buf.get(..length).ok_or_else(|| Error::OutOfSpec("Fixed sized binary declares a number of values times size larger than the page buffer".to_string()))?.to_vec(); + + Ok(FixedLenByteArrayPageDict::new(values, size)) +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/mod.rs b/crates/polars/tests/it/io/parquet/read/dictionary/mod.rs new file mode 100644 index 000000000000..4dcb2afbaf70 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/mod.rs @@ -0,0 +1,56 @@ +mod binary; +mod fixed_len_binary; +mod primitive; + +pub use binary::BinaryPageDict; +pub use fixed_len_binary::FixedLenByteArrayPageDict; +use polars_parquet::parquet::error::{Error, Result}; +use polars_parquet::parquet::page::DictPage; +use polars_parquet::parquet::schema::types::PhysicalType; +pub use primitive::PrimitivePageDict; + +pub enum DecodedDictPage { + Int32(PrimitivePageDict), + Int64(PrimitivePageDict), + Int96(PrimitivePageDict<[u32; 3]>), + Float(PrimitivePageDict), + Double(PrimitivePageDict), + ByteArray(BinaryPageDict), + FixedLenByteArray(FixedLenByteArrayPageDict), +} + +pub fn deserialize(page: &DictPage, physical_type: PhysicalType) -> Result { + _deserialize(&page.buffer, page.num_values, page.is_sorted, physical_type) +} + +fn _deserialize( + buf: &[u8], + num_values: usize, + is_sorted: bool, + physical_type: PhysicalType, +) -> Result { + match physical_type { + PhysicalType::Boolean => Err(Error::OutOfSpec( + "Boolean physical type cannot be dictionary-encoded".to_string(), + )), + PhysicalType::Int32 => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Int32) + }, + PhysicalType::Int64 => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Int64) + }, + PhysicalType::Int96 => { + primitive::read::<[u32; 3]>(buf, num_values, is_sorted).map(DecodedDictPage::Int96) + }, + PhysicalType::Float => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Float) + }, + PhysicalType::Double => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Double) + }, + PhysicalType::ByteArray => binary::read(buf, num_values).map(DecodedDictPage::ByteArray), + PhysicalType::FixedLenByteArray(size) => { + fixed_len_binary::read(buf, size, num_values).map(DecodedDictPage::FixedLenByteArray) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs b/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs new file mode 100644 index 000000000000..aeeccf10eb5b --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs @@ -0,0 +1,47 @@ +use polars_parquet::parquet::error::{Error, Result}; +use polars_parquet::parquet::types::{decode, NativeType}; + +#[derive(Debug)] +pub struct PrimitivePageDict { + values: Vec, +} + +impl PrimitivePageDict { + pub fn new(values: Vec) -> Self { + Self { values } + } + + pub fn values(&self) -> &[T] { + &self.values + } + + #[inline] + pub fn value(&self, index: usize) -> Result<&T> { + self.values.get(index).ok_or_else(|| { + Error::OutOfSpec( + "The data page has an index larger than the dictionary page values".to_string(), + ) + }) + } +} + +pub fn read( + buf: &[u8], + num_values: usize, + _is_sorted: bool, +) -> Result> { + let size_of = std::mem::size_of::(); + + let typed_size = num_values.wrapping_mul(size_of); + + let values = buf.get(..typed_size).ok_or_else(|| { + Error::OutOfSpec( + "The number of values declared in the dict page does not match the length of the page" + .to_string(), + ) + })?; + + let values = values.chunks_exact(size_of).map(decode::).collect(); + + Ok(PrimitivePageDict::new(values)) +} diff --git a/crates/polars/tests/it/io/parquet/read/fixed_binary.rs b/crates/polars/tests/it/io/parquet/read/fixed_binary.rs new file mode 100644 index 000000000000..b0e7472163f2 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/fixed_binary.rs @@ -0,0 +1,34 @@ +use polars_parquet::parquet::deserialize::FixedLenBinaryPageState; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::page::DataPage; + +use super::dictionary::FixedLenByteArrayPageDict; +use super::utils::deserialize_optional; + +pub fn page_to_vec( + page: &DataPage, + dict: Option<&FixedLenByteArrayPageDict>, +) -> Result>>> { + assert_eq!(page.descriptor.max_rep_level, 0); + + let state = FixedLenBinaryPageState::try_new(page, dict)?; + + match state { + FixedLenBinaryPageState::Optional(validity, values) => { + deserialize_optional(validity, values.map(|x| Ok(x.to_vec()))) + }, + FixedLenBinaryPageState::Required(values) => { + Ok(values.map(|x| x.to_vec()).map(Some).collect()) + }, + FixedLenBinaryPageState::RequiredDictionary(dict) => dict + .indexes + .map(|x| x.and_then(|x| dict.dict.value(x as usize).map(|x| x.to_vec()).map(Some))) + .collect(), + FixedLenBinaryPageState::OptionalDictionary(validity, dict) => { + let values = dict + .indexes + .map(|x| x.and_then(|x| dict.dict.value(x as usize).map(|x| x.to_vec()))); + deserialize_optional(validity, values) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/indexes.rs b/crates/polars/tests/it/io/parquet/read/indexes.rs new file mode 100644 index 000000000000..ad79c6d04544 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/indexes.rs @@ -0,0 +1,143 @@ +use polars_parquet::parquet::error::Error; +use polars_parquet::parquet::indexes::{ + BooleanIndex, BoundaryOrder, ByteIndex, Index, NativeIndex, PageIndex, PageLocation, +}; +use polars_parquet::parquet::read::{read_columns_indexes, read_metadata, read_pages_locations}; +use polars_parquet::parquet::schema::types::{ + FieldInfo, PhysicalType, PrimitiveConvertedType, PrimitiveLogicalType, PrimitiveType, +}; +use polars_parquet::parquet::schema::Repetition; + +/* +import pyspark.sql # 3.2.1 +spark = pyspark.sql.SparkSession.builder.getOrCreate() +spark.conf.set("parquet.bloom.filter.enabled", True) +spark.conf.set("parquet.bloom.filter.expected.ndv", 10) +spark.conf.set("parquet.bloom.filter.max.bytes", 32) + +data = [(i, f"{i}", False) for i in range(10)] +df = spark.createDataFrame(data, ["id", "string", "bool"]).repartition(1) + +df.write.parquet("bla.parquet", mode = "overwrite") +*/ +const FILE: &[u8] = &[ + 80, 65, 82, 49, 21, 0, 21, 172, 1, 21, 138, 1, 21, 169, 161, 209, 137, 5, 28, 21, 20, 21, 0, + 21, 6, 21, 8, 0, 0, 86, 24, 2, 0, 0, 0, 20, 1, 0, 13, 1, 17, 9, 1, 22, 1, 1, 0, 3, 1, 5, 12, 0, + 0, 0, 4, 1, 5, 12, 0, 0, 0, 5, 1, 5, 12, 0, 0, 0, 6, 1, 5, 12, 0, 0, 0, 7, 1, 5, 72, 0, 0, 0, + 8, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 21, 0, 21, 112, 21, 104, 21, 138, 239, 232, + 170, 15, 28, 21, 20, 21, 0, 21, 6, 21, 8, 0, 0, 56, 40, 2, 0, 0, 0, 20, 1, 1, 0, 0, 0, 48, 1, + 5, 0, 49, 1, 5, 0, 50, 1, 5, 0, 51, 1, 5, 0, 52, 1, 5, 0, 53, 1, 5, 60, 54, 1, 0, 0, 0, 55, 1, + 0, 0, 0, 56, 1, 0, 0, 0, 57, 21, 0, 21, 16, 21, 20, 21, 202, 209, 169, 227, 4, 28, 21, 20, 21, + 0, 21, 6, 21, 8, 0, 0, 8, 28, 2, 0, 0, 0, 20, 1, 0, 0, 25, 17, 2, 25, 24, 8, 0, 0, 0, 0, 0, 0, + 0, 0, 25, 24, 8, 9, 0, 0, 0, 0, 0, 0, 0, 21, 2, 25, 22, 0, 0, 25, 17, 2, 25, 24, 1, 48, 25, 24, + 1, 57, 21, 2, 25, 22, 0, 0, 25, 17, 2, 25, 24, 1, 0, 25, 24, 1, 0, 21, 2, 25, 22, 0, 0, 25, 28, + 22, 8, 21, 188, 1, 22, 0, 0, 0, 25, 28, 22, 196, 1, 21, 150, 1, 22, 0, 0, 0, 25, 28, 22, 218, + 2, 21, 66, 22, 0, 0, 0, 21, 64, 28, 28, 0, 0, 28, 28, 0, 0, 28, 28, 0, 0, 0, 24, 130, 24, 8, + 134, 8, 68, 6, 2, 101, 128, 10, 64, 2, 38, 78, 114, 1, 64, 38, 1, 192, 194, 152, 64, 70, 0, 36, + 56, 121, 64, 0, 21, 64, 28, 28, 0, 0, 28, 28, 0, 0, 28, 28, 0, 0, 0, 8, 17, 10, 29, 5, 88, 194, + 0, 35, 208, 25, 16, 70, 68, 48, 38, 17, 16, 140, 68, 98, 56, 0, 131, 4, 193, 40, 129, 161, 160, + 1, 96, 21, 64, 28, 28, 0, 0, 28, 28, 0, 0, 28, 28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 2, 25, 76, 72, 12, 115, 112, + 97, 114, 107, 95, 115, 99, 104, 101, 109, 97, 21, 6, 0, 21, 4, 37, 2, 24, 2, 105, 100, 0, 21, + 12, 37, 2, 24, 6, 115, 116, 114, 105, 110, 103, 37, 0, 76, 28, 0, 0, 0, 21, 0, 37, 2, 24, 4, + 98, 111, 111, 108, 0, 22, 20, 25, 28, 25, 60, 38, 8, 28, 21, 4, 25, 53, 0, 6, 8, 25, 24, 2, + 105, 100, 21, 2, 22, 20, 22, 222, 1, 22, 188, 1, 38, 8, 60, 24, 8, 9, 0, 0, 0, 0, 0, 0, 0, 24, + 8, 0, 0, 0, 0, 0, 0, 0, 0, 22, 0, 40, 8, 9, 0, 0, 0, 0, 0, 0, 0, 24, 8, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 25, 28, 21, 0, 21, 0, 21, 2, 0, 22, 226, 4, 0, 22, 158, 4, 21, 22, 22, 156, 3, 21, 62, 0, + 38, 196, 1, 28, 21, 12, 25, 53, 0, 6, 8, 25, 24, 6, 115, 116, 114, 105, 110, 103, 21, 2, 22, + 20, 22, 158, 1, 22, 150, 1, 38, 196, 1, 60, 54, 0, 40, 1, 57, 24, 1, 48, 0, 25, 28, 21, 0, 21, + 0, 21, 2, 0, 22, 192, 5, 0, 22, 180, 4, 21, 24, 22, 218, 3, 21, 34, 0, 38, 218, 2, 28, 21, 0, + 25, 53, 0, 6, 8, 25, 24, 4, 98, 111, 111, 108, 21, 2, 22, 20, 22, 62, 22, 66, 38, 218, 2, 60, + 24, 1, 0, 24, 1, 0, 22, 0, 40, 1, 0, 24, 1, 0, 0, 25, 28, 21, 0, 21, 0, 21, 2, 0, 22, 158, 6, + 0, 22, 204, 4, 21, 22, 22, 252, 3, 21, 34, 0, 22, 186, 3, 22, 20, 38, 8, 22, 148, 3, 20, 0, 0, + 25, 44, 24, 24, 111, 114, 103, 46, 97, 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, + 118, 101, 114, 115, 105, 111, 110, 24, 5, 51, 46, 50, 46, 49, 0, 24, 41, 111, 114, 103, 46, 97, + 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, 115, 113, 108, 46, 112, 97, 114, 113, + 117, 101, 116, 46, 114, 111, 119, 46, 109, 101, 116, 97, 100, 97, 116, 97, 24, 213, 1, 123, 34, + 116, 121, 112, 101, 34, 58, 34, 115, 116, 114, 117, 99, 116, 34, 44, 34, 102, 105, 101, 108, + 100, 115, 34, 58, 91, 123, 34, 110, 97, 109, 101, 34, 58, 34, 105, 100, 34, 44, 34, 116, 121, + 112, 101, 34, 58, 34, 108, 111, 110, 103, 34, 44, 34, 110, 117, 108, 108, 97, 98, 108, 101, 34, + 58, 116, 114, 117, 101, 44, 34, 109, 101, 116, 97, 100, 97, 116, 97, 34, 58, 123, 125, 125, 44, + 123, 34, 110, 97, 109, 101, 34, 58, 34, 115, 116, 114, 105, 110, 103, 34, 44, 34, 116, 121, + 112, 101, 34, 58, 34, 115, 116, 114, 105, 110, 103, 34, 44, 34, 110, 117, 108, 108, 97, 98, + 108, 101, 34, 58, 116, 114, 117, 101, 44, 34, 109, 101, 116, 97, 100, 97, 116, 97, 34, 58, 123, + 125, 125, 44, 123, 34, 110, 97, 109, 101, 34, 58, 34, 98, 111, 111, 108, 34, 44, 34, 116, 121, + 112, 101, 34, 58, 34, 98, 111, 111, 108, 101, 97, 110, 34, 44, 34, 110, 117, 108, 108, 97, 98, + 108, 101, 34, 58, 116, 114, 117, 101, 44, 34, 109, 101, 116, 97, 100, 97, 116, 97, 34, 58, 123, + 125, 125, 93, 125, 0, 24, 74, 112, 97, 114, 113, 117, 101, 116, 45, 109, 114, 32, 118, 101, + 114, 115, 105, 111, 110, 32, 49, 46, 49, 50, 46, 50, 32, 40, 98, 117, 105, 108, 100, 32, 55, + 55, 101, 51, 48, 99, 56, 48, 57, 51, 51, 56, 54, 101, 99, 53, 50, 99, 51, 99, 102, 97, 54, 99, + 51, 52, 98, 55, 101, 102, 51, 51, 50, 49, 51, 50, 50, 99, 57, 52, 41, 25, 60, 28, 0, 0, 28, 0, + 0, 28, 0, 0, 0, 182, 2, 0, 0, 80, 65, 82, 49, +]; + +#[test] +fn test() -> Result<(), Error> { + let mut reader = std::io::Cursor::new(FILE); + + let expected_index = vec![ + Box::new(NativeIndex:: { + primitive_type: PrimitiveType::from_physical("id".to_string(), PhysicalType::Int64), + indexes: vec![PageIndex { + min: Some(0), + max: Some(9), + null_count: Some(0), + }], + boundary_order: BoundaryOrder::Ascending, + }) as Box, + Box::new(ByteIndex { + primitive_type: PrimitiveType { + field_info: FieldInfo { + name: "string".to_string(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: Some(PrimitiveLogicalType::String), + converted_type: Some(PrimitiveConvertedType::Utf8), + physical_type: PhysicalType::ByteArray, + }, + indexes: vec![PageIndex { + min: Some(b"0".to_vec()), + max: Some(b"9".to_vec()), + null_count: Some(0), + }], + boundary_order: BoundaryOrder::Ascending, + }), + Box::new(BooleanIndex { + indexes: vec![PageIndex { + min: Some(false), + max: Some(false), + null_count: Some(0), + }], + boundary_order: BoundaryOrder::Ascending, + }), + ]; + let expected_page_locations = vec![ + vec![PageLocation { + offset: 4, + compressed_page_size: 94, + first_row_index: 0, + }], + vec![PageLocation { + offset: 98, + compressed_page_size: 75, + first_row_index: 0, + }], + vec![PageLocation { + offset: 173, + compressed_page_size: 33, + first_row_index: 0, + }], + ]; + + let metadata = read_metadata(&mut reader)?; + let columns = &metadata.row_groups[0].columns(); + + let indexes = read_columns_indexes(&mut reader, columns)?; + assert_eq!(&indexes, &expected_index); + + let pages = read_pages_locations(&mut reader, columns)?; + assert_eq!(pages, expected_page_locations); + + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/read/mod.rs b/crates/polars/tests/it/io/parquet/read/mod.rs new file mode 100644 index 000000000000..49ec00fa2e0f --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/mod.rs @@ -0,0 +1,432 @@ +/// Serialization to Rust's Native types. +/// In comparison to Arrow, this in-memory format does not leverage logical types nor SIMD operations, +/// but OTOH it has no external dependencies and is very familiar to Rust developers. +mod binary; +mod boolean; +mod deserialize; +mod dictionary; +mod fixed_binary; +mod indexes; +mod primitive; +mod primitive_nested; +mod struct_; +mod utils; + +use std::fs::File; + +use dictionary::{deserialize as deserialize_dict, DecodedDictPage}; +#[cfg(feature = "async")] +use futures::StreamExt; +use polars_parquet::parquet::error::{Error, Result}; +use polars_parquet::parquet::metadata::ColumnChunkMetaData; +use polars_parquet::parquet::page::{CompressedPage, DataPage, Page}; +#[cfg(feature = "async")] +use polars_parquet::parquet::read::get_page_stream; +#[cfg(feature = "async")] +use polars_parquet::parquet::read::read_metadata_async; +use polars_parquet::parquet::read::{ + get_column_iterator, get_field_columns, read_metadata, BasicDecompressor, MutStreamingIterator, + State, +}; +use polars_parquet::parquet::schema::types::{GroupConvertedType, ParquetType}; +use polars_parquet::parquet::schema::Repetition; +use polars_parquet::parquet::types::int96_to_i64_ns; +use polars_parquet::parquet::FallibleStreamingIterator; + +use super::*; + +pub fn get_path() -> PathBuf { + let dir = env!("CARGO_MANIFEST_DIR"); + PathBuf::from(dir).join("../../docs/data") +} + +/// Reads a page into an [`Array`]. +/// This is CPU-intensive: decompress, decode and de-serialize. +pub fn page_to_array(page: &DataPage, dict: Option<&DecodedDictPage>) -> Result { + let physical_type = page.descriptor.primitive_type.physical_type; + match page.descriptor.max_rep_level { + 0 => match physical_type { + PhysicalType::Boolean => Ok(Array::Boolean(boolean::page_to_vec(page)?)), + PhysicalType::Int32 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int32(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Int32) + }, + PhysicalType::Int64 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int64(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Int64) + }, + PhysicalType::Int96 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int96(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Int96) + }, + PhysicalType::Float => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Float(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Float) + }, + PhysicalType::Double => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Double(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Double) + }, + PhysicalType::ByteArray => { + let dict = dict.map(|dict| { + if let DecodedDictPage::ByteArray(dict) = dict { + dict + } else { + panic!() + } + }); + + binary::page_to_vec(page, dict).map(Array::Binary) + }, + PhysicalType::FixedLenByteArray(_) => { + let dict = dict.map(|dict| { + if let DecodedDictPage::FixedLenByteArray(dict) = dict { + dict + } else { + panic!() + } + }); + + fixed_binary::page_to_vec(page, dict).map(Array::FixedLenBinary) + }, + }, + _ => match dict { + None => match physical_type { + PhysicalType::Int64 => Ok(primitive_nested::page_to_array::(page, None)?), + _ => todo!(), + }, + Some(_) => match physical_type { + PhysicalType::Int64 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int64(dict) = dict { + dict + } else { + panic!() + } + }); + Ok(primitive_nested::page_dict_to_array(page, dict)?) + }, + _ => todo!(), + }, + }, + } +} + +pub fn collect>( + mut iterator: I, + type_: PhysicalType, +) -> Result> { + let mut arrays = vec![]; + let mut dict = None; + while let Some(page) = iterator.next()? { + match page { + Page::Data(page) => arrays.push(page_to_array(page, dict.as_ref())?), + Page::Dict(page) => { + dict = Some(deserialize_dict(page, type_)?); + }, + } + } + Ok(arrays) +} + +/// Reads columns into an [`Array`]. +/// This is CPU-intensive: decompress, decode and de-serialize. +pub fn columns_to_array(mut columns: I, field: &ParquetType) -> Result +where + II: Iterator>, + I: MutStreamingIterator, +{ + let mut validity = vec![]; + let mut has_filled = false; + let mut arrays = vec![]; + while let State::Some(mut new_iter) = columns.advance()? { + if let Some((pages, column)) = new_iter.get() { + let mut iterator = BasicDecompressor::new(pages, vec![]); + + let mut dict = None; + while let Some(page) = iterator.next()? { + match page { + polars_parquet::parquet::page::Page::Data(page) => { + if !has_filled { + struct_::extend_validity(&mut validity, page)?; + } + arrays.push(page_to_array(page, dict.as_ref())?) + }, + polars_parquet::parquet::page::Page::Dict(page) => { + dict = Some(deserialize_dict(page, column.physical_type())?); + }, + } + } + } + has_filled = true; + columns = new_iter; + } + + match field { + ParquetType::PrimitiveType { .. } => { + arrays.pop().ok_or_else(|| Error::OutOfSpec("".to_string())) + }, + ParquetType::GroupType { converted_type, .. } => { + if let Some(converted_type) = converted_type { + match converted_type { + GroupConvertedType::List => Ok(arrays.pop().unwrap()), + _ => todo!(), + } + } else { + Ok(Array::Struct(arrays, validity)) + } + }, + } +} + +pub fn read_column( + reader: &mut R, + row_group: usize, + field_name: &str, +) -> Result<(Array, Option>)> { + let metadata = read_metadata(reader)?; + + let field = metadata + .schema() + .fields() + .iter() + .find(|field| field.name() == field_name) + .ok_or_else(|| Error::OutOfSpec("column does not exist".to_string()))?; + + let columns = get_column_iterator( + reader, + &metadata.row_groups[row_group], + field.name(), + None, + vec![], + usize::MAX, + ); + + let mut statistics = get_field_columns(metadata.row_groups[row_group].columns(), field.name()) + .map(|column_meta| column_meta.statistics().transpose()) + .collect::>>()?; + + let array = columns_to_array(columns, field)?; + + Ok((array, statistics.pop().unwrap())) +} + +#[cfg(feature = "async")] +pub async fn read_column_async< + R: futures::AsyncRead + futures::AsyncSeek + Send + std::marker::Unpin, +>( + reader: &mut R, + row_group: usize, + field_name: &str, +) -> Result<(Array, Option>)> { + let metadata = read_metadata_async(reader).await?; + + let field = metadata + .schema() + .fields() + .iter() + .find(|field| field.name() == field_name) + .ok_or_else(|| Error::OutOfSpec("column does not exist".to_string()))?; + + let column = get_field_columns(metadata.row_groups[row_group].columns(), field.name()) + .next() + .unwrap(); + + let pages = get_page_stream(column, reader, vec![], Arc::new(|_, _| true), usize::MAX).await?; + + let mut statistics = get_field_columns(metadata.row_groups[row_group].columns(), field.name()) + .map(|column_meta| column_meta.statistics().transpose()) + .collect::>>()?; + + let pages = pages.collect::>().await; + + let iterator = BasicDecompressor::new(pages.into_iter(), vec![]); + + let mut arrays = collect(iterator, column.physical_type())?; + + Ok((arrays.pop().unwrap(), statistics.pop().unwrap())) +} + +fn get_column(path: &str, column: &str) -> Result<(Array, Option>)> { + let mut file = File::open(path).unwrap(); + read_column(&mut file, 0, column) +} + +fn test_column(column: &str) -> Result<()> { + let mut path = get_path(); + path.push("alltypes_plain.parquet"); + let path = path.to_str().unwrap(); + let (result, statistics) = get_column(path, column)?; + // the file does not have statistics + assert_eq!(statistics.as_ref().map(|x| x.as_ref()), None); + assert_eq!(result, alltypes_plain(column)); + Ok(()) +} + +#[test] +fn int32() -> Result<()> { + test_column("id") +} + +#[test] +fn bool() -> Result<()> { + test_column("bool_col") +} + +#[test] +fn tinyint_col() -> Result<()> { + test_column("tinyint_col") +} + +#[test] +fn smallint_col() -> Result<()> { + test_column("smallint_col") +} + +#[test] +fn int_col() -> Result<()> { + test_column("int_col") +} + +#[test] +fn bigint_col() -> Result<()> { + test_column("bigint_col") +} + +#[test] +fn float_col() -> Result<()> { + test_column("float_col") +} + +#[test] +fn double_col() -> Result<()> { + test_column("double_col") +} + +#[test] +fn timestamp_col() -> Result<()> { + let mut path = get_path(); + path.push("alltypes_plain.parquet"); + let path = path.to_str().unwrap(); + + let expected = vec![ + 1235865600000000000i64, + 1235865660000000000, + 1238544000000000000, + 1238544060000000000, + 1233446400000000000, + 1233446460000000000, + 1230768000000000000, + 1230768060000000000, + ]; + + let expected = expected.into_iter().map(Some).collect::>(); + let (array, _) = get_column(path, "timestamp_col")?; + if let Array::Int96(array) = array { + let a = array + .into_iter() + .map(|x| x.map(int96_to_i64_ns)) + .collect::>(); + assert_eq!(expected, a); + } else { + panic!("Timestamp expected"); + }; + Ok(()) +} + +#[test] +fn test_metadata() -> Result<()> { + let mut testdata = get_path(); + testdata.push("alltypes_plain.parquet"); + let mut file = File::open(testdata).unwrap(); + + let metadata = read_metadata(&mut file)?; + + let columns = metadata.schema_descr.columns(); + + /* + from pyarrow: + required group field_id=0 schema { + optional int32 field_id=1 id; + optional boolean field_id=2 bool_col; + optional int32 field_id=3 tinyint_col; + optional int32 field_id=4 smallint_col; + optional int32// pub enum Value { + // UInt32(Option), + // Int32(Option), + // Int64(Option), + // Int96(Option<[u32; 3]>), + // Float32(Option), + // Float64(Option), + // Boolean(Option), + // Binary(Option>), + // FixedLenBinary(Option>), + // List(Option), + // } + field_id=5 int_col; + optional int64 field_id=6 bigint_col; + optional float field_id=7 float_col; + optional double field_id=8 double_col; + optional binary field_id=9 date_string_col; + optional binary field_id=10 string_col; + optional int96 field_id=11 timestamp_col; + } + */ + let expected = vec![ + PhysicalType::Int32, + PhysicalType::Boolean, + PhysicalType::Int32, + PhysicalType::Int32, + PhysicalType::Int32, + PhysicalType::Int64, + PhysicalType::Float, + PhysicalType::Double, + PhysicalType::ByteArray, + PhysicalType::ByteArray, + PhysicalType::Int96, + ]; + + let result = columns + .iter() + .map(|column| { + assert_eq!( + column.descriptor.primitive_type.field_info.repetition, + Repetition::Optional + ); + column.descriptor.primitive_type.physical_type + }) + .collect::>(); + + assert_eq!(expected, result); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/read/primitive.rs b/crates/polars/tests/it/io/parquet/read/primitive.rs new file mode 100644 index 000000000000..0146e6a94835 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/primitive.rs @@ -0,0 +1,116 @@ +use polars_parquet::parquet::deserialize::{ + native_cast, Casted, HybridRleDecoderIter, HybridRleIter, NativePageState, OptionalValues, + SliceFilteredIter, +}; +use polars_parquet::parquet::encoding::hybrid_rle::Decoder; +use polars_parquet::parquet::encoding::Encoding; +use polars_parquet::parquet::error::Error; +use polars_parquet::parquet::page::{split_buffer, DataPage}; +use polars_parquet::parquet::schema::Repetition; +use polars_parquet::parquet::types::NativeType; + +use super::dictionary::PrimitivePageDict; +use super::utils::deserialize_optional; + +/// The deserialization state of a `DataPage` of `Primitive` parquet primitive type +#[derive(Debug)] +pub enum FilteredPageState<'a, T> +where + T: NativeType, +{ + /// A page of optional values + Optional(SliceFilteredIter, Casted<'a, T>>>), + /// A page of required values + Required(SliceFilteredIter>), +} + +/// The deserialization state of a `DataPage` of `Primitive` parquet primitive type +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum PageState<'a, T> +where + T: NativeType, +{ + Nominal(NativePageState<'a, T, &'a PrimitivePageDict>), + Filtered(FilteredPageState<'a, T>), +} + +impl<'a, T: NativeType> PageState<'a, T> { + /// Tries to create [`NativePageState`] + /// # Error + /// Errors iff the page is not a `NativePageState` + pub fn try_new( + page: &'a DataPage, + dict: Option<&'a PrimitivePageDict>, + ) -> Result { + if let Some(selected_rows) = page.selected_rows() { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + + match (page.encoding(), dict, is_optional) { + (Encoding::Plain, _, true) => { + let (_, def_levels, _) = split_buffer(page)?; + + let validity = HybridRleDecoderIter::new(HybridRleIter::new( + Decoder::new(def_levels, 1), + page.num_values(), + )); + let values = native_cast(page)?; + + // validity and values interleaved. + let values = OptionalValues::new(validity, values); + + let values = + SliceFilteredIter::new(values, selected_rows.iter().copied().collect()); + + Ok(Self::Filtered(FilteredPageState::Optional(values))) + }, + (Encoding::Plain, _, false) => { + let values = SliceFilteredIter::new( + native_cast(page)?, + selected_rows.iter().copied().collect(), + ); + Ok(Self::Filtered(FilteredPageState::Required(values))) + }, + _ => Err(Error::FeatureNotSupported(format!( + "Viewing page for encoding {:?} for native type {}", + page.encoding(), + std::any::type_name::() + ))), + } + } else { + NativePageState::try_new(page, dict).map(Self::Nominal) + } + } +} + +pub fn page_to_vec( + page: &DataPage, + dict: Option<&PrimitivePageDict>, +) -> Result>, Error> { + assert_eq!(page.descriptor.max_rep_level, 0); + let state = PageState::::try_new(page, dict)?; + + match state { + PageState::Nominal(state) => match state { + NativePageState::Optional(validity, mut values) => { + deserialize_optional(validity, values.by_ref().map(Ok)) + }, + NativePageState::Required(values) => Ok(values.map(Some).collect()), + NativePageState::RequiredDictionary(dict) => dict + .indexes + .map(|x| x.and_then(|x| dict.dict.value(x as usize).copied().map(Some))) + .collect(), + NativePageState::OptionalDictionary(validity, dict) => { + let values = dict + .indexes + .map(|x| x.and_then(|x| dict.dict.value(x as usize).copied())); + deserialize_optional(validity, values) + }, + }, + PageState::Filtered(state) => match state { + FilteredPageState::Optional(values) => values.collect(), + FilteredPageState::Required(values) => Ok(values.map(Some).collect()), + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/primitive_nested.rs b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs new file mode 100644 index 000000000000..6ab120ee2c33 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs @@ -0,0 +1,230 @@ +use polars_parquet::parquet::encoding::hybrid_rle::HybridRleDecoder; +use polars_parquet::parquet::encoding::{bitpacked, uleb128, Encoding}; +use polars_parquet::parquet::error::Error; +use polars_parquet::parquet::page::{split_buffer, DataPage}; +use polars_parquet::parquet::read::levels::get_bit_width; +use polars_parquet::parquet::types::NativeType; + +use super::dictionary::PrimitivePageDict; +use super::Array; + +fn read_buffer(values: &[u8]) -> impl Iterator + '_ { + let chunks = values.chunks_exact(std::mem::size_of::()); + chunks.map(|chunk| { + // unwrap is infalible due to the chunk size. + let chunk: T::Bytes = match chunk.try_into() { + Ok(v) => v, + Err(_) => panic!(), + }; + T::from_le_bytes(chunk) + }) +} + +// todo: generalize i64 -> T +fn compose_array< + I: Iterator>, + F: Iterator>, + G: Iterator, +>( + rep_levels: I, + def_levels: F, + max_rep: u32, + max_def: u32, + mut values: G, +) -> Result { + let mut outer = vec![]; + let mut inner = vec![]; + + assert_eq!(max_rep, 1); + assert_eq!(max_def, 3); + let mut prev_def = 0; + rep_levels + .into_iter() + .zip(def_levels.into_iter()) + .try_for_each(|(rep, def)| { + let rep = rep?; + let def = def?; + match rep { + 1 => {}, + 0 => { + if prev_def > 1 { + let old = std::mem::take(&mut inner); + outer.push(Some(Array::Int64(old))); + } + }, + _ => unreachable!(), + } + match def { + 3 => inner.push(Some(values.next().unwrap())), + 2 => inner.push(None), + 1 => outer.push(Some(Array::Int64(vec![]))), + 0 => outer.push(None), + _ => unreachable!(), + } + prev_def = def; + Ok::<(), Error>(()) + })?; + outer.push(Some(Array::Int64(inner))); + Ok(Array::List(outer)) +} + +fn read_array_impl>( + rep_levels: &[u8], + def_levels: &[u8], + values: I, + length: usize, + rep_level_encoding: (&Encoding, i16), + def_level_encoding: (&Encoding, i16), +) -> Result { + let max_rep_level = rep_level_encoding.1 as u32; + let max_def_level = def_level_encoding.1 as u32; + + match ( + (rep_level_encoding.0, max_rep_level == 0), + (def_level_encoding.0, max_def_level == 0), + ) { + ((Encoding::Rle, true), (Encoding::Rle, true)) => compose_array( + std::iter::repeat(Ok(0)).take(length), + std::iter::repeat(Ok(0)).take(length), + max_rep_level, + max_def_level, + values, + ), + ((Encoding::Rle, false), (Encoding::Rle, true)) => { + let num_bits = get_bit_width(rep_level_encoding.1); + let rep_levels = HybridRleDecoder::try_new(rep_levels, num_bits, length)?; + compose_array( + rep_levels, + std::iter::repeat(Ok(0)).take(length), + max_rep_level, + max_def_level, + values, + ) + }, + ((Encoding::Rle, true), (Encoding::Rle, false)) => { + let num_bits = get_bit_width(def_level_encoding.1); + let def_levels = HybridRleDecoder::try_new(def_levels, num_bits, length)?; + compose_array( + std::iter::repeat(Ok(0)).take(length), + def_levels, + max_rep_level, + max_def_level, + values, + ) + }, + ((Encoding::Rle, false), (Encoding::Rle, false)) => { + let rep_levels = + HybridRleDecoder::try_new(rep_levels, get_bit_width(rep_level_encoding.1), length)?; + let def_levels = + HybridRleDecoder::try_new(def_levels, get_bit_width(def_level_encoding.1), length)?; + compose_array(rep_levels, def_levels, max_rep_level, max_def_level, values) + }, + _ => todo!(), + } +} + +fn read_array( + rep_levels: &[u8], + def_levels: &[u8], + values: &[u8], + length: u32, + rep_level_encoding: (&Encoding, i16), + def_level_encoding: (&Encoding, i16), +) -> Result { + let values = read_buffer::(values); + read_array_impl::<_>( + rep_levels, + def_levels, + values, + length as usize, + rep_level_encoding, + def_level_encoding, + ) +} + +pub fn page_to_array( + page: &DataPage, + dict: Option<&PrimitivePageDict>, +) -> Result { + let (rep_levels, def_levels, values) = split_buffer(page)?; + + match (&page.encoding(), dict) { + (Encoding::Plain, None) => read_array( + rep_levels, + def_levels, + values, + page.num_values() as u32, + ( + &page.repetition_level_encoding(), + page.descriptor.max_rep_level, + ), + ( + &page.definition_level_encoding(), + page.descriptor.max_def_level, + ), + ), + _ => todo!(), + } +} + +fn read_dict_array( + rep_levels: &[u8], + def_levels: &[u8], + values: &[u8], + length: u32, + dict: &PrimitivePageDict, + rep_level_encoding: (&Encoding, i16), + def_level_encoding: (&Encoding, i16), +) -> Result { + let dict_values = dict.values(); + + let bit_width = values[0]; + let values = &values[1..]; + + let (_, consumed) = uleb128::decode(values)?; + let values = &values[consumed..]; + + let indices = bitpacked::Decoder::::try_new(values, bit_width as usize, length as usize)?; + + let values = indices.map(|id| dict_values[id as usize]); + + read_array_impl::<_>( + rep_levels, + def_levels, + values, + length as usize, + rep_level_encoding, + def_level_encoding, + ) +} + +pub fn page_dict_to_array( + page: &DataPage, + dict: Option<&PrimitivePageDict>, +) -> Result { + assert_eq!(page.descriptor.max_rep_level, 1); + + let (rep_levels, def_levels, values) = split_buffer(page)?; + + match (page.encoding(), dict) { + (Encoding::PlainDictionary, Some(dict)) => read_dict_array( + rep_levels, + def_levels, + values, + page.num_values() as u32, + dict, + ( + &page.repetition_level_encoding(), + page.descriptor.max_rep_level, + ), + ( + &page.definition_level_encoding(), + page.descriptor.max_def_level, + ), + ), + (_, None) => Err(Error::OutOfSpec( + "A dictionary-encoded page MUST be preceded by a dictionary page".to_string(), + )), + _ => todo!(), + } +} diff --git a/crates/polars/tests/it/io/parquet/read/struct_.rs b/crates/polars/tests/it/io/parquet/read/struct_.rs new file mode 100644 index 000000000000..29e67a74c7ca --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/struct_.rs @@ -0,0 +1,27 @@ +use polars_parquet::parquet::encoding::hybrid_rle::HybridRleDecoder; +use polars_parquet::parquet::error::Error; +use polars_parquet::parquet::page::{split_buffer, DataPage}; +use polars_parquet::parquet::read::levels::get_bit_width; + +pub fn extend_validity(val: &mut Vec, page: &DataPage) -> Result<(), Error> { + let (_, def_levels, _) = split_buffer(page)?; + let length = page.num_values(); + + if page.descriptor.max_def_level == 0 { + return Ok(()); + } + + let def_level_encoding = ( + &page.definition_level_encoding(), + page.descriptor.max_def_level, + ); + + let mut def_levels = + HybridRleDecoder::try_new(def_levels, get_bit_width(def_level_encoding.1), length)?; + + val.reserve(length); + def_levels.try_for_each(|x| { + val.push(x? != 0); + Ok(()) + }) +} diff --git a/crates/polars/tests/it/io/parquet/read/utils.rs b/crates/polars/tests/it/io/parquet/read/utils.rs new file mode 100644 index 000000000000..7214417d81a1 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/utils.rs @@ -0,0 +1,66 @@ +use polars_parquet::parquet::deserialize::{ + DefLevelsDecoder, HybridDecoderBitmapIter, HybridEncoded, +}; +use polars_parquet::parquet::encoding::hybrid_rle::{BitmapIter, HybridRleDecoder}; +use polars_parquet::parquet::error::Error; + +pub fn deserialize_optional>>( + validity: DefLevelsDecoder, + values: I, +) -> Result>, Error> { + match validity { + DefLevelsDecoder::Bitmap(bitmap) => deserialize_bitmap(bitmap, values), + DefLevelsDecoder::Levels(levels, max_level) => { + deserialize_levels(levels, max_level, values) + }, + } +} + +fn deserialize_bitmap>>( + mut validity: HybridDecoderBitmapIter, + mut values: I, +) -> Result>, Error> { + let mut deserialized = Vec::with_capacity(validity.len()); + + validity.try_for_each(|run| match run? { + HybridEncoded::Bitmap(bitmap, length) => { + BitmapIter::new(bitmap, 0, length).try_for_each(|x| { + if x { + deserialized.push(values.next().transpose()?); + } else { + deserialized.push(None); + } + Result::<_, Error>::Ok(()) + }) + }, + HybridEncoded::Repeated(is_set, length) => { + if is_set { + deserialized.reserve(length); + for x in values.by_ref().take(length) { + deserialized.push(Some(x?)) + } + } else { + deserialized.extend(std::iter::repeat(None).take(length)) + } + Ok(()) + }, + })?; + Ok(deserialized) +} + +fn deserialize_levels>>( + levels: HybridRleDecoder, + max: u32, + mut values: I, +) -> Result>, Error> { + levels + .into_iter() + .map(|x| { + if x? == max { + values.next().transpose() + } else { + Ok(None) + } + }) + .collect() +} diff --git a/crates/polars-parquet/tests/it/roundtrip.rs b/crates/polars/tests/it/io/parquet/roundtrip.rs similarity index 100% rename from crates/polars-parquet/tests/it/roundtrip.rs rename to crates/polars/tests/it/io/parquet/roundtrip.rs diff --git a/crates/polars/tests/it/io/parquet/write/binary.rs b/crates/polars/tests/it/io/parquet/write/binary.rs new file mode 100644 index 000000000000..add477530fec --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/binary.rs @@ -0,0 +1,87 @@ +use polars_parquet::parquet::encoding::hybrid_rle::encode_bool; +use polars_parquet::parquet::encoding::Encoding; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::metadata::Descriptor; +use polars_parquet::parquet::page::{DataPage, DataPageHeader, DataPageHeaderV1, Page}; +use polars_parquet::parquet::statistics::{serialize_statistics, BinaryStatistics, Statistics}; +use polars_parquet::parquet::types::ord_binary; +use polars_parquet::parquet::write::WriteOptions; + +fn unzip_option(array: &[Option>]) -> Result<(Vec, Vec)> { + // leave the first 4 bytes anouncing the length of the def level + // this will be overwritten at the end, once the length is known. + // This is unknown at this point because of the uleb128 encoding, + // whose length is variable. + let mut validity = std::io::Cursor::new(vec![0; 4]); + validity.set_position(4); + + let mut values = vec![]; + let iter = array.iter().map(|value| { + if let Some(item) = value { + values.extend_from_slice(&(item.len() as i32).to_le_bytes()); + values.extend_from_slice(item.as_ref()); + true + } else { + false + } + }); + encode_bool(&mut validity, iter)?; + + // write the length, now that it is known + let mut validity = validity.into_inner(); + let length = validity.len() - 4; + // todo: pay this small debt (loop?) + let length = length.to_le_bytes(); + validity[0] = length[0]; + validity[1] = length[1]; + validity[2] = length[2]; + validity[3] = length[3]; + + Ok((values, validity)) +} + +pub fn array_to_page_v1( + array: &[Option>], + options: &WriteOptions, + descriptor: &Descriptor, +) -> Result { + let (values, mut buffer) = unzip_option(array)?; + + buffer.extend_from_slice(&values); + + let statistics = if options.write_statistics { + let statistics = &BinaryStatistics { + primitive_type: descriptor.primitive_type.clone(), + null_count: Some((array.len() - array.iter().flatten().count()) as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max_by(|x, y| ord_binary(x, y)) + .cloned(), + min_value: array + .iter() + .flatten() + .min_by(|x, y| ord_binary(x, y)) + .cloned(), + } as &dyn Statistics; + Some(serialize_statistics(statistics)) + } else { + None + }; + + let header = DataPageHeaderV1 { + num_values: array.len() as i32, + encoding: Encoding::Plain.into(), + definition_level_encoding: Encoding::Rle.into(), + repetition_level_encoding: Encoding::Rle.into(), + statistics, + }; + + Ok(Page::Data(DataPage::new( + DataPageHeader::V1(header), + buffer, + descriptor.clone(), + Some(array.len()), + ))) +} diff --git a/crates/polars/tests/it/io/parquet/write/indexes.rs b/crates/polars/tests/it/io/parquet/write/indexes.rs new file mode 100644 index 000000000000..44c13b55ca50 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/indexes.rs @@ -0,0 +1,133 @@ +use std::io::Cursor; + +use polars_parquet::parquet::compression::CompressionOptions; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::indexes::{ + select_pages, BoundaryOrder, Index, Interval, NativeIndex, PageIndex, PageLocation, +}; +use polars_parquet::parquet::metadata::SchemaDescriptor; +use polars_parquet::parquet::read::{ + read_columns_indexes, read_metadata, read_pages_locations, BasicDecompressor, IndexedPageReader, +}; +use polars_parquet::parquet::schema::types::{ParquetType, PhysicalType, PrimitiveType}; +use polars_parquet::parquet::write::{ + Compressor, DynIter, DynStreamingIterator, FileWriter, Version, WriteOptions, +}; + +use super::super::read::collect; +use super::primitive::array_to_page_v1; +use super::Array; + +fn write_file() -> Result> { + let page1 = vec![Some(0), Some(1), None, Some(3), Some(4), Some(5), Some(6)]; + let page2 = vec![Some(10), Some(11)]; + + let options = WriteOptions { + write_statistics: true, + version: Version::V1, + }; + + let schema = SchemaDescriptor::new( + "schema".to_string(), + vec![ParquetType::from_physical( + "col1".to_string(), + PhysicalType::Int32, + )], + ); + + let pages = vec![ + array_to_page_v1::(&page1, &options, &schema.columns()[0].descriptor), + array_to_page_v1::(&page2, &options, &schema.columns()[0].descriptor), + ]; + + let pages = DynStreamingIterator::new(Compressor::new( + DynIter::new(pages.into_iter()), + CompressionOptions::Uncompressed, + vec![], + )); + let columns = std::iter::once(Ok(pages)); + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::new(writer, schema, options, None); + + writer.write(DynIter::new(columns))?; + writer.end(None)?; + + Ok(writer.into_inner().into_inner()) +} + +#[test] +fn read_indexed_page() -> Result<()> { + let data = write_file()?; + let mut reader = Cursor::new(data); + + let metadata = read_metadata(&mut reader)?; + + let column = 0; + let columns = &metadata.row_groups[0].columns(); + + // selected the rows + let intervals = &[Interval::new(2, 2)]; + + let pages = read_pages_locations(&mut reader, columns)?; + + let pages = select_pages(intervals, &pages[column], metadata.row_groups[0].num_rows())?; + + let pages = IndexedPageReader::new(reader, &columns[column], pages, vec![], vec![]); + + let pages = BasicDecompressor::new(pages, vec![]); + + let arrays = collect(pages, columns[column].physical_type())?; + + // the second item and length 2 + assert_eq!(arrays, vec![Array::Int32(vec![None, Some(3)])]); + + Ok(()) +} + +#[test] +fn read_indexes_and_locations() -> Result<()> { + let data = write_file()?; + let mut reader = Cursor::new(data); + + let metadata = read_metadata(&mut reader)?; + + let columns = &metadata.row_groups[0].columns(); + + let expected_page_locations = vec![vec![ + PageLocation { + offset: 4, + compressed_page_size: 63, + first_row_index: 0, + }, + PageLocation { + offset: 67, + compressed_page_size: 47, + first_row_index: 7, + }, + ]]; + let expected_index = vec![Box::new(NativeIndex:: { + primitive_type: PrimitiveType::from_physical("col1".to_string(), PhysicalType::Int32), + indexes: vec![ + PageIndex { + min: Some(0), + max: Some(6), + null_count: Some(1), + }, + PageIndex { + min: Some(10), + max: Some(11), + null_count: Some(0), + }, + ], + boundary_order: BoundaryOrder::Unordered, + }) as Box]; + + let indexes = read_columns_indexes(&mut reader, columns)?; + assert_eq!(&indexes, &expected_index); + + let pages = read_pages_locations(&mut reader, columns)?; + assert_eq!(pages, expected_page_locations); + + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/write/mod.rs b/crates/polars/tests/it/io/parquet/write/mod.rs new file mode 100644 index 000000000000..dbb90b7a87b7 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/mod.rs @@ -0,0 +1,289 @@ +mod binary; +mod indexes; +mod primitive; +mod sidecar; + +use std::io::{Cursor, Read, Seek}; +use std::sync::Arc; + +use polars_parquet::parquet::compression::{BrotliLevel, CompressionOptions}; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::metadata::{Descriptor, SchemaDescriptor}; +use polars_parquet::parquet::page::Page; +use polars_parquet::parquet::schema::types::{ParquetType, PhysicalType}; +use polars_parquet::parquet::statistics::Statistics; +#[cfg(feature = "async")] +use polars_parquet::parquet::write::FileStreamer; +use polars_parquet::parquet::write::{ + Compressor, DynIter, DynStreamingIterator, FileWriter, Version, WriteOptions, +}; +use polars_parquet::read::read_metadata; +use primitive::array_to_page_v1; + +use super::{alltypes_plain, alltypes_statistics, Array}; + +pub fn array_to_page( + array: &Array, + options: &WriteOptions, + descriptor: &Descriptor, +) -> Result { + // using plain encoding format + match array { + Array::Int32(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Int64(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Int96(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Float(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Double(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Binary(array) => binary::array_to_page_v1(array, options, descriptor), + _ => todo!(), + } +} + +fn read_column(reader: &mut R) -> Result<(Array, Option>)> { + let (a, statistics) = super::read::read_column(reader, 0, "col")?; + Ok((a, statistics)) +} + +#[cfg(feature = "async")] +#[allow(dead_code)] +async fn read_column_async< + R: futures::AsyncRead + futures::AsyncSeek + Send + std::marker::Unpin, +>( + reader: &mut R, +) -> Result<(Array, Option>)> { + let (a, statistics) = super::read::read_column_async(reader, 0, "col").await?; + Ok((a, statistics)) +} + +fn test_column(column: &str, compression: CompressionOptions) -> Result<()> { + let array = alltypes_plain(column); + + let options = WriteOptions { + write_statistics: true, + version: Version::V1, + }; + + // prepare schema + let type_ = match array { + Array::Int32(_) => PhysicalType::Int32, + Array::Int64(_) => PhysicalType::Int64, + Array::Int96(_) => PhysicalType::Int96, + Array::Float(_) => PhysicalType::Float, + Array::Double(_) => PhysicalType::Double, + Array::Binary(_) => PhysicalType::ByteArray, + _ => todo!(), + }; + + let schema = SchemaDescriptor::new( + "schema".to_string(), + vec![ParquetType::from_physical("col".to_string(), type_)], + ); + + let a = schema.columns(); + + let pages = DynStreamingIterator::new(Compressor::new_from_vec( + DynIter::new(std::iter::once(array_to_page( + &array, + &options, + &a[0].descriptor, + ))), + compression, + vec![], + )); + let columns = std::iter::once(Ok(pages)); + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::new(writer, schema, options, None); + + writer.write(DynIter::new(columns))?; + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + + let (result, statistics) = read_column(&mut Cursor::new(data))?; + assert_eq!(array, result); + let stats = alltypes_statistics(column); + assert_eq!( + statistics.as_ref().map(|x| x.as_ref()), + Some(stats).as_ref().map(|x| x.as_ref()) + ); + Ok(()) +} + +#[test] +fn int32() -> Result<()> { + test_column("id", CompressionOptions::Uncompressed) +} + +#[test] +fn int32_snappy() -> Result<()> { + test_column("id", CompressionOptions::Snappy) +} + +#[test] +fn int32_lz4() -> Result<()> { + test_column("id", CompressionOptions::Lz4Raw) +} + +#[test] +fn int32_lz4_short_i32_array() -> Result<()> { + test_column("id-short-array", CompressionOptions::Lz4Raw) +} + +#[test] +fn int32_brotli() -> Result<()> { + test_column( + "id", + CompressionOptions::Brotli(Some(BrotliLevel::default())), + ) +} + +#[test] +#[ignore = "Native boolean writer not yet implemented"] +fn bool() -> Result<()> { + test_column("bool_col", CompressionOptions::Uncompressed) +} + +#[test] +fn tinyint() -> Result<()> { + test_column("tinyint_col", CompressionOptions::Uncompressed) +} + +#[test] +fn smallint_col() -> Result<()> { + test_column("smallint_col", CompressionOptions::Uncompressed) +} + +#[test] +fn int_col() -> Result<()> { + test_column("int_col", CompressionOptions::Uncompressed) +} + +#[test] +fn bigint_col() -> Result<()> { + test_column("bigint_col", CompressionOptions::Uncompressed) +} + +#[test] +fn float_col() -> Result<()> { + test_column("float_col", CompressionOptions::Uncompressed) +} + +#[test] +fn double_col() -> Result<()> { + test_column("double_col", CompressionOptions::Uncompressed) +} + +#[test] +fn basic() -> Result<()> { + let array = vec![ + Some(0), + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + ]; + + let options = WriteOptions { + write_statistics: false, + version: Version::V1, + }; + + let schema = SchemaDescriptor::new( + "schema".to_string(), + vec![ParquetType::from_physical( + "col".to_string(), + PhysicalType::Int32, + )], + ); + + let pages = DynStreamingIterator::new(Compressor::new_from_vec( + DynIter::new(std::iter::once(array_to_page_v1( + &array, + &options, + &schema.columns()[0].descriptor, + ))), + CompressionOptions::Uncompressed, + vec![], + )); + let columns = std::iter::once(Ok(pages)); + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::new(writer, schema, options, None); + + writer.write(DynIter::new(columns))?; + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + let mut reader = Cursor::new(data); + + let metadata = read_metadata(&mut reader)?; + + // validated against an equivalent array produced by pyarrow. + let expected = 51; + assert_eq!( + metadata.row_groups[0].columns()[0].uncompressed_size(), + expected + ); + + Ok(()) +} + +#[cfg(feature = "async")] +#[allow(dead_code)] +async fn test_column_async(column: &str, compression: CompressionOptions) -> Result<()> { + let array = alltypes_plain(column); + + let options = WriteOptions { + write_statistics: true, + version: Version::V1, + }; + + // prepare schema + let type_ = match array { + Array::Int32(_) => PhysicalType::Int32, + Array::Int64(_) => PhysicalType::Int64, + Array::Int96(_) => PhysicalType::Int96, + Array::Float(_) => PhysicalType::Float, + Array::Double(_) => PhysicalType::Double, + Array::Binary(_) => PhysicalType::ByteArray, + _ => todo!(), + }; + + let schema = SchemaDescriptor::new( + "schema".to_string(), + vec![ParquetType::from_physical("col".to_string(), type_)], + ); + + let a = schema.columns(); + + let pages = DynStreamingIterator::new(Compressor::new_from_vec( + DynIter::new(std::iter::once(array_to_page( + &array, + &options, + &a[0].descriptor, + ))), + compression, + vec![], + )); + let columns = std::iter::once(Ok(pages)); + + let writer = futures::io::Cursor::new(vec![]); + let mut writer = FileStreamer::new(writer, schema, options, None); + + writer.write(DynIter::new(columns)).await?; + writer.end(None).await?; + + let data = writer.into_inner().into_inner(); + + let (result, statistics) = read_column_async(&mut futures::io::Cursor::new(data)).await?; + assert_eq!(array, result); + let stats = alltypes_statistics(column); + assert_eq!( + statistics.as_ref().map(|x| x.as_ref()), + Some(stats).as_ref().map(|x| x.as_ref()) + ); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/write/primitive.rs b/crates/polars/tests/it/io/parquet/write/primitive.rs new file mode 100644 index 000000000000..9cab7f0977f9 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/primitive.rs @@ -0,0 +1,78 @@ +use polars_parquet::parquet::encoding::hybrid_rle::encode_bool; +use polars_parquet::parquet::encoding::Encoding; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::metadata::Descriptor; +use polars_parquet::parquet::page::{DataPage, DataPageHeader, DataPageHeaderV1, Page}; +use polars_parquet::parquet::statistics::{serialize_statistics, PrimitiveStatistics, Statistics}; +use polars_parquet::parquet::types::NativeType; +use polars_parquet::parquet::write::WriteOptions; + +fn unzip_option(array: &[Option]) -> Result<(Vec, Vec)> { + // leave the first 4 bytes anouncing the length of the def level + // this will be overwritten at the end, once the length is known. + // This is unknown at this point because of the uleb128 encoding, + // whose length is variable. + let mut validity = std::io::Cursor::new(vec![0; 4]); + validity.set_position(4); + + let mut values = vec![]; + let iter = array.iter().map(|value| { + if let Some(item) = value { + values.extend_from_slice(item.to_le_bytes().as_ref()); + true + } else { + false + } + }); + encode_bool(&mut validity, iter)?; + + // write the length, now that it is known + let mut validity = validity.into_inner(); + let length = validity.len() - 4; + // todo: pay this small debt (loop?) + let length = length.to_le_bytes(); + validity[0] = length[0]; + validity[1] = length[1]; + validity[2] = length[2]; + validity[3] = length[3]; + + Ok((values, validity)) +} + +pub fn array_to_page_v1( + array: &[Option], + options: &WriteOptions, + descriptor: &Descriptor, +) -> Result { + let (values, mut buffer) = unzip_option(array)?; + + buffer.extend_from_slice(&values); + + let statistics = if options.write_statistics { + let statistics = &PrimitiveStatistics { + primitive_type: descriptor.primitive_type.clone(), + null_count: Some((array.len() - array.iter().flatten().count()) as i64), + distinct_count: None, + max_value: array.iter().flatten().max_by(|x, y| x.ord(y)).copied(), + min_value: array.iter().flatten().min_by(|x, y| x.ord(y)).copied(), + } as &dyn Statistics; + Some(serialize_statistics(statistics)) + } else { + None + }; + + let header = DataPageHeaderV1 { + num_values: array.len() as i32, + encoding: Encoding::Plain.into(), + definition_level_encoding: Encoding::Rle.into(), + repetition_level_encoding: Encoding::Rle.into(), + statistics, + }; + + Ok(Page::Data(DataPage::new( + DataPageHeader::V1(header), + buffer, + descriptor.clone(), + Some(array.len()), + ))) +} diff --git a/crates/polars/tests/it/io/parquet/write/sidecar.rs b/crates/polars/tests/it/io/parquet/write/sidecar.rs new file mode 100644 index 000000000000..f1c654a23d98 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/sidecar.rs @@ -0,0 +1,55 @@ +use polars_parquet::parquet::error::Error; +use polars_parquet::parquet::metadata::SchemaDescriptor; +use polars_parquet::parquet::schema::types::{ParquetType, PhysicalType}; +use polars_parquet::parquet::write::{write_metadata_sidecar, FileWriter, Version, WriteOptions}; + +#[test] +fn basic() -> Result<(), Error> { + let schema = SchemaDescriptor::new( + "schema".to_string(), + vec![ParquetType::from_physical( + "c1".to_string(), + PhysicalType::Int32, + )], + ); + + let mut metadatas = vec![]; + for i in 0..10 { + // say we will write 10 files + let relative_path = format!("part-{i}.parquet"); + let writer = std::io::Cursor::new(vec![]); + let mut writer = FileWriter::new( + writer, + schema.clone(), + WriteOptions { + write_statistics: true, + version: Version::V2, + }, + None, + ); + writer.end(None)?; + let (_, mut metadata) = writer.into_inner_and_metadata(); + + // once done, we write their relative paths: + metadata.row_groups.iter_mut().for_each(|row_group| { + row_group + .columns + .iter_mut() + .for_each(|column| column.file_path = Some(relative_path.clone())) + }); + metadatas.push(metadata); + } + + // merge their row groups + let first = metadatas.pop().unwrap(); + let sidecar = metadatas.into_iter().fold(first, |mut acc, metadata| { + acc.row_groups.extend(metadata.row_groups); + acc + }); + + // and write the metadata on a separate file + let mut writer = std::io::Cursor::new(vec![]); + write_metadata_sidecar(&mut writer, &sidecar)?; + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/group_by.rs b/crates/polars/tests/it/lazy/group_by.rs index 16459855ea3f..4e24e1d24fa8 100644 --- a/crates/polars/tests/it/lazy/group_by.rs +++ b/crates/polars/tests/it/lazy/group_by.rs @@ -9,7 +9,7 @@ use super::*; #[test] #[cfg(feature = "rank")] fn test_filter_sort_diff_2984() -> PolarsResult<()> { - // make sort that sort doest not oob if filter returns no values + // make sure that sort does not oob if filter returns no values let df = df![ "group"=> ["A" ,"A", "A", "B", "B", "B", "B"], "id"=> [1, 2, 1, 4, 5, 4, 6], diff --git a/crates/polars/tests/it/main.rs b/crates/polars/tests/it/main.rs index 8cf14da210c3..de6fd0d7d33e 100644 --- a/crates/polars/tests/it/main.rs +++ b/crates/polars/tests/it/main.rs @@ -6,4 +6,6 @@ mod lazy; mod schema; mod time; +mod arrow; + pub static FOODS_CSV: &str = "../../examples/datasets/foods1.csv"; diff --git a/crates/polars/tests/it/time/date_range.rs b/crates/polars/tests/it/time/date_range.rs index 9de815368d77..ff8df835cce2 100644 --- a/crates/polars/tests/it/time/date_range.rs +++ b/crates/polars/tests/it/time/date_range.rs @@ -1,6 +1,7 @@ use polars::export::chrono::NaiveDate; use polars::prelude::*; -use polars::time::{date_range, ClosedWindow, Duration}; +#[allow(unused_imports)] +use polars::time::date_range; #[test] fn test_time_units_9413() { diff --git a/docs/data/alltypes_plain.parquet b/docs/data/alltypes_plain.parquet new file mode 100644 index 000000000000..a63f5dca7c38 Binary files /dev/null and b/docs/data/alltypes_plain.parquet differ diff --git a/docs/development/contributing/ci.md b/docs/development/contributing/ci.md index b85330590092..bd771f79a20c 100644 --- a/docs/development/contributing/ci.md +++ b/docs/development/contributing/ci.md @@ -6,11 +6,11 @@ Polars uses GitHub Actions as its continuous integration (CI) tool. The setup is Overall, the CI suite aims to achieve the following: -• Enforce code correctness by running automated tests. -• Enforce code quality by running automated linting checks. -• Enforce code performance by running benchmark tests. -• Enforce that code is properly documented. -• Allow maintainers to easily publish new releases. +- Enforce code correctness by running automated tests. +- Enforce code quality by running automated linting checks. +- Enforce code performance by running benchmark tests. +- Enforce that code is properly documented. +- Allow maintainers to easily publish new releases. We rely on a wide range of tools to achieve this for both the Rust and the Python code base, and thus a lot of checks are triggered on each pull request. @@ -20,9 +20,9 @@ It's entirely possible that you submit a relatively trivial fix that subsequentl The CI setup is designed with the following requirements in mind: -• Get feedback on each step individually. We want to avoid our test job being cancelled because a linting check failed, only to find out later that we also have a failing test. -• Get feedback on each check as quickly as possible. We want to be able to iterate quickly if it turns out our code does not pass some of the checks. -• Only run checks when they need to be run. A change to the Rust code does not warrant a linting check of the Python code, for example. +- Get feedback on each step individually. We want to avoid our test job being cancelled because a linting check failed, only to find out later that we also have a failing test. +- Get feedback on each check as quickly as possible. We want to be able to iterate quickly if it turns out our code does not pass some of the checks. +- Only run checks when they need to be run. A change to the Rust code does not warrant a linting check of the Python code, for example. This results in a modular setup with many separate workflows and jobs that rely heavily on caching. diff --git a/docs/development/contributing/ide.md b/docs/development/contributing/ide.md index a77c83837d64..12bd94cab229 100644 --- a/docs/development/contributing/ide.md +++ b/docs/development/contributing/ide.md @@ -7,9 +7,11 @@ This page contains some recommendations for configuring popular IDEs. Make sure to configure VSCode to use the virtual environment created by the Makefile. -In addition, the extensions below are recommended. +### Extensions -### rust-analyzer +The extensions below are recommended. + +#### rust-analyzer If you work on the Rust code at all, you will need the [rust-analyzer](https://marketplace.visualstudio.com/items?itemName=rust-lang.rust-analyzer) extension. This extension provides code completion for the Rust code. @@ -21,9 +23,9 @@ For it to work well for the Polars code base, add the following settings to your } ``` -### Ruff +#### Ruff -The Ruff extension will help you conform to the formatting requirements of the Python code. +The [Ruff](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) extension will help you conform to the formatting requirements of the Python code. We use both the Ruff linter and formatter. It is recommended to configure the extension to use the Ruff installed in your environment. This will make it use the correct Ruff version and configuration. @@ -34,6 +36,92 @@ This will make it use the correct Ruff version and configuration. } ``` +#### CodeLLDB + +The [CodeLLDB](https://marketplace.visualstudio.com/items?itemName=vadimcn.vscode-lldb) extension is useful for debugging Rust code. +You can also debug Rust code called from Python (see section below). + +### Debugging + +Due to the way that Python and Rust interoperate, debugging the Rust side of development from Python calls can be difficult. +This guide shows how to set up a debugging environment that makes debugging Rust code called from a Python script painless. + +#### Preparation + +Start by installing the CodeLLDB extension (see above). +Then add the following two configurations to your `launch.json` file. +This file is usually found in the `.vscode` folder of your project root. +See the [official VSCode documentation](https://code.visualstudio.com/docs/editor/debugging#_launch-configurations) for more information about the `launch.json` file. + +

launch.json + +```json +{ + "configurations": [ + { + "name": "Debug Rust/Python", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/py-polars/debug/launch.py", + "args": [ + "${file}" + ], + "console": "internalConsole", + "justMyCode": true, + "serverReadyAction": { + "pattern": "pID = ([0-9]+)", + "action": "startDebugging", + "name": "Rust LLDB" + } + }, + { + "name": "Rust LLDB", + "pid": "0", + "type": "lldb", + "request": "attach", + "program": "${workspaceFolder}/py-polars/.venv/bin/python", + "stopOnEntry": false, + "sourceLanguages": [ + "rust" + ], + "presentation": { + "hidden": true + } + } + ] +} +``` + +
+ +!!! info + + On some systems, the LLDB debugger will not attach unless [ptrace protection](https://linux-audit.com/protect-ptrace-processes-kernel-yama-ptrace_scope) is disabled. + To disable, run the following command: + + ```shell + echo 0 | sudo tee /proc/sys/kernel/yama/ptrace_scope + ``` + +#### Running the debugger + +1. Create a Python script containing Polars code. Ensure that your virtual environment is activated. + +2. Set breakpoints in any `.rs` or `.py` file. + +3. In the `Run and Debug` panel on the left, select `Debug Rust/Python` from the drop-down menu on top and click the `Start Debugging` button. + +At this point, your debugger should stop on breakpoints in any `.rs` file located within the codebase. + +#### Details + +The debugging feature runs via the specially-designed VSCode launch configuration shown above. +The initial Python debugger is launched using a special launch script located at `py-polars/debug/launch.py` and passes the name of the script to be debugged (the target script) as an input argument. +The launch script determines the process ID, writes this value into the `launch.json` configuration file, compiles the target script and runs it in the current environment. +At this point, a second (Rust) debugger is attached to the Python debugger. +The result is two simultaneous debuggers operating on the same running instance. +Breakpoints in the Python code will stop on the Python debugger and breakpoints in the Rust code will stop on the Rust debugger. + ## PyCharm / RustRover / CLion !!! info diff --git a/docs/development/contributing/index.md b/docs/development/contributing/index.md index e78517df33c7..809a149e5160 100644 --- a/docs/development/contributing/index.md +++ b/docs/development/contributing/index.md @@ -128,7 +128,7 @@ When you have resolved your issue, [open a pull request](https://docs.github.com Please adhere to the following guidelines: - Start your pull request title with a [conventional commit](https://www.conventionalcommits.org/) tag. This helps us add your contribution to the right section of the changelog. We use the [Angular convention](https://github.com/angular/angular/blob/22b96b9/CONTRIBUTING.md#type). Scope can be `rust` and/or `python`, depending on your contribution. -- Use a descriptive title. This text will end up in the [changelog](https://github.com/pola-rs/polars/releases). +- Use a descriptive title starting with an uppercase letter. This text will end up in the [changelog](https://github.com/pola-rs/polars/releases). - In the pull request description, [link](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) to the issue you were working on. - Add any relevant information to the description that you think may help the maintainers review your code. - Make sure your branch is [rebased](https://docs.github.com/en/get-started/using-git/about-git-rebase) against the latest version of the `main` branch. @@ -150,9 +150,9 @@ The user guide is maintained in the `docs/user-guide` folder. Before creating a #### Building and serving the user guide -The user guide is built using [MkDocs](https://www.mkdocs.org/). You install the dependencies for building the user guide by running `make requirements` in the root of the repo. +The user guide is built using [MkDocs](https://www.mkdocs.org/). You install the dependencies for building the user guide by running `make build` in the root of the repo. -Run `mkdocs serve` to build and serve the user guide, so you can view it locally and see updates as you make changes. +Activate the virtual environment and run `mkdocs serve` to build and serve the user guide, so you can view it locally and see updates as you make changes. #### Creating a new user guide page @@ -227,7 +227,7 @@ From the `py-polars` directory, run `make fmt` to make sure your additions pass Polars uses Sphinx to build the API reference. This means docstrings in general should follow the [reST](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html) format. -If you want to build the API reference locally, go to the `py-polars/docs` directory and run `make html SPHINXOPTS=-W`. +If you want to build the API reference locally, go to the `py-polars/docs` directory and run `make html`. The resulting HTML files will be in `py-polars/docs/build/html`. New additions to the API should be added manually to the API reference by adding an entry to the correct `.rst` file in the `py-polars/docs/source/reference` directory. diff --git a/docs/development/contributing/test.md b/docs/development/contributing/test.md index cf102f0fdf5c..67051f1d57cc 100644 --- a/docs/development/contributing/test.md +++ b/docs/development/contributing/test.md @@ -23,14 +23,15 @@ These tests are intended to make sure all Polars functionality works as intended Run unit tests by running `make test` from the `py-polars` folder. This will compile the Rust bindings and then run the unit tests. -If you're working in the Python code only, you can avoid recompiling every time by simply running `pytest` instead. +If you're working in the Python code only, you can avoid recompiling every time by simply running `pytest` instead from your virtual environment. By default, slow tests are skipped. Slow tests are marked as such using a [custom pytest marker](https://docs.pytest.org/en/latest/example/markers.html). If you wish to run slow tests, run `pytest -m slow`. Or run `pytest -m ""` to run _all_ tests, regardless of marker. -Tests can be run in parallel using [`pytext-xdist`](https://pytest-xdist.readthedocs.io/en/latest/). Run `pytest -n auto` to parallelize your test run. +Tests can be run in parallel by running `pytest -n auto`. +The parallelization is handled by [`pytest-xdist`](https://pytest-xdist.readthedocs.io/en/latest/). ### Writing unit tests diff --git a/docs/development/versioning.md b/docs/development/versioning.md index d57bb92fd459..2d8009e8dbe5 100644 --- a/docs/development/versioning.md +++ b/docs/development/versioning.md @@ -80,7 +80,10 @@ Such changes will not be warned for, but _will_ be included in the changelog and ### Deprecation period As a rule, deprecated functionality is removed two breaking releases after the deprecation happens. -For example, a function deprecated in version `0.18.3` will be removed in version `0.20.0`. +For example: + +- Before the release of `1.0.0`: a function deprecated in version `0.18.3` will be removed in version `0.20.0` +- After the release of `1.0.0`: a function deprecated in version `1.2.3` will be removed in version `3.0.0` This means that if your program does not raise any deprecation warnings, it should be mostly safe to upgrade to the next breaking release. As breaking releases happen about once every three months, this allows three to six months to adjust to any pending breaking changes. diff --git a/docs/mlc-config.json b/docs/mlc-config.json index e77aed6c4d0e..e9a807932ac3 100644 --- a/docs/mlc-config.json +++ b/docs/mlc-config.json @@ -2,6 +2,8 @@ "ignorePatterns": [ { "pattern": "^https://crates.io/" + },{ + "pattern": "^https://stackoverflow.com/" } ] } diff --git a/docs/pyproject.toml b/docs/pyproject.toml index 56a8e8e1c04c..5bbaa41487a3 100644 --- a/docs/pyproject.toml +++ b/docs/pyproject.toml @@ -1,6 +1,7 @@ [tool.ruff] fix = true +[tool.ruff.lint] ignore = [ "E402", # Module level import not at top of file ] diff --git a/docs/requirements.txt b/docs/requirements.txt index d0a5a5d8193f..dccf92dd62d1 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,10 +2,13 @@ pandas pyarrow graphviz matplotlib +seaborn +plotly +altair mkdocs-material==9.5.2 mkdocs-macros-plugin==1.0.4 mkdocs-redirects==1.2.1 material-plausible-plugin==0.2.0 -markdown-exec[ansi]==1.7.0 -PyGithub==2.1.1 +markdown-exec[ansi]==1.8.0 +PyGithub==2.2.0 diff --git a/docs/src/python/user-guide/basics/series-dataframes.py b/docs/src/python/user-guide/concepts/data-structures.py similarity index 62% rename from docs/src/python/user-guide/basics/series-dataframes.py rename to docs/src/python/user-guide/concepts/data-structures.py index 3171da06adbc..edc1a2a25c3c 100644 --- a/docs/src/python/user-guide/basics/series-dataframes.py +++ b/docs/src/python/user-guide/concepts/data-structures.py @@ -5,27 +5,6 @@ print(s) # --8<-- [end:series] -# --8<-- [start:minmax] -s = pl.Series("a", [1, 2, 3, 4, 5]) -print(s.min()) -print(s.max()) -# --8<-- [end:minmax] - -# --8<-- [start:string] -s = pl.Series("a", ["polar", "bear", "arctic", "polar fox", "polar bear"]) -s2 = s.str.replace("polar", "pola") -print(s2) -# --8<-- [end:string] - -# --8<-- [start:dt] -from datetime import date - -start = date(2001, 1, 1) -stop = date(2001, 1, 9) -s = pl.date_range(start, stop, interval="2d", eager=True) -print(s.dt.day()) -# --8<-- [end:dt] - # --8<-- [start:dataframe] from datetime import datetime diff --git a/docs/src/python/user-guide/expressions/null.py b/docs/src/python/user-guide/expressions/missing-data.py similarity index 100% rename from docs/src/python/user-guide/expressions/null.py rename to docs/src/python/user-guide/expressions/missing-data.py diff --git a/docs/src/python/user-guide/basics/expressions.py b/docs/src/python/user-guide/getting-started/expressions.py similarity index 100% rename from docs/src/python/user-guide/basics/expressions.py rename to docs/src/python/user-guide/getting-started/expressions.py diff --git a/docs/src/python/user-guide/basics/joins.py b/docs/src/python/user-guide/getting-started/joins.py similarity index 100% rename from docs/src/python/user-guide/basics/joins.py rename to docs/src/python/user-guide/getting-started/joins.py diff --git a/docs/src/python/user-guide/basics/reading-writing.py b/docs/src/python/user-guide/getting-started/reading-writing.py similarity index 100% rename from docs/src/python/user-guide/basics/reading-writing.py rename to docs/src/python/user-guide/getting-started/reading-writing.py diff --git a/docs/src/python/user-guide/lazy/schema.py b/docs/src/python/user-guide/lazy/schema.py index e621718307ee..5cdf3c657c98 100644 --- a/docs/src/python/user-guide/lazy/schema.py +++ b/docs/src/python/user-guide/lazy/schema.py @@ -9,10 +9,19 @@ print(q3.schema) # --8<-- [end:schema] -# --8<-- [start:typecheck] -pl.DataFrame({"foo": ["a", "b", "c"], "bar": [0, 1, 2]}).lazy().with_columns( - pl.col("bar").round(0) +# --8<-- [start:lazyround] +q4 = ( + pl.DataFrame({"foo": ["a", "b", "c"], "bar": [0, 1, 2]}) + .lazy() + .with_columns(pl.col("bar").round(0)) ) +# --8<-- [end:lazyround] + +# --8<-- [start:typecheck] +try: + print(q4.collect()) +except Exception as e: + print(e) # --8<-- [end:typecheck] # --8<-- [start:lazyeager] diff --git a/docs/src/python/user-guide/misc/visualization.py b/docs/src/python/user-guide/misc/visualization.py new file mode 100644 index 000000000000..f04288cb7812 --- /dev/null +++ b/docs/src/python/user-guide/misc/visualization.py @@ -0,0 +1,130 @@ +# --8<-- [start:dataframe] +import polars as pl + +path = "docs/data/iris.csv" + +df = pl.scan_csv(path).group_by("species").agg(pl.col("petal_length").mean()).collect() +print(df) +# --8<-- [end:dataframe] + +""" +# --8<-- [start:hvplot_show_plot] +df.plot.bar( + x="species", + y="petal_length", + width=650, +) +# --8<-- [end:hvplot_show_plot] +""" + +# --8<-- [start:hvplot_make_plot] +import hvplot + +plot = df.plot.bar( + x="species", + y="petal_length", + width=650, +) +hvplot.save(plot, "docs/images/hvplot_bar.html") +with open("docs/images/hvplot_bar.html", "r") as f: + chart_html = f.read() + print(f"{chart_html}") +# --8<-- [end:hvplot_make_plot] + +""" +# --8<-- [start:matplotlib_show_plot] +import matplotlib.pyplot as plt + +plt.bar(x=df["species"], height=df["petal_length"]) +# --8<-- [end:matplotlib_show_plot] +""" + +# --8<-- [start:matplotlib_make_plot] +import base64 + +import matplotlib.pyplot as plt + +plt.bar(x=df["species"], height=df["petal_length"]) +plt.savefig("docs/images/matplotlib_bar.png") +with open("docs/images/matplotlib_bar.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:matplotlib_make_plot] + +""" +# --8<-- [start:seaborn_show_plot] +import seaborn as sns +sns.barplot( + df, + x="species", + y="petal_length", +) +# --8<-- [end:seaborn_show_plot] +""" + +# --8<-- [start:seaborn_make_plot] +import seaborn as sns + +sns.barplot( + df, + x="species", + y="petal_length", +) +plt.savefig("docs/images/seaborn_bar.png") +with open("docs/images/seaborn_bar.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:seaborn_make_plot] + +""" +# --8<-- [start:plotly_show_plot] +import plotly.express as px + +px.bar( + df, + x="species", + y="petal_length", + width=400, +) +# --8<-- [end:plotly_show_plot] +""" + +# --8<-- [start:plotly_make_plot] +import plotly.express as px + +fig = px.bar( + df, + x="species", + y="petal_length", + width=650, +) +fig.write_html("docs/images/plotly_bar.html", full_html=False, include_plotlyjs="cdn") +with open("docs/images/plotly_bar.html", "r") as f: + chart_html = f.read() + print(f"{chart_html}") +# --8<-- [end:plotly_make_plot] + +""" +# --8<-- [start:altair_show_plot] +import altair as alt + +alt.Chart(df, width=700).mark_bar().encode(x="species:N", y="petal_length:Q") +# --8<-- [end:altair_show_plot] +""" + +# --8<-- [start:altair_make_plot] +import altair as alt + +chart = ( + alt.Chart(df, width=600) + .mark_bar() + .encode( + x="species:N", + y="petal_length:Q", + ) +) +chart.save("docs/images/altair_bar.html") +with open("docs/images/altair_bar.html", "r") as f: + chart_html = f.read() + print(f"{chart_html}") +# --8<-- [end:altair_make_plot] diff --git a/docs/src/rust/Cargo.toml b/docs/src/rust/Cargo.toml index da1ce364ab03..96e31ebd04b6 100644 --- a/docs/src/rust/Cargo.toml +++ b/docs/src/rust/Cargo.toml @@ -25,16 +25,19 @@ path = "home/example.rs" required-features = ["polars/lazy"] [[bin]] -name = "user-guide-basics-expressions" -path = "user-guide/basics/expressions.rs" +name = "user-guide-getting-started-expressions" +path = "user-guide/getting-started/expressions.rs" required-features = ["polars/lazy"] [[bin]] -name = "user-guide-basics-joins" -path = "user-guide/basics/joins.rs" +name = "user-guide-getting-started-joins" +path = "user-guide/getting-started/joins.rs" [[bin]] -name = "user-guide-basics-reading-writing" -path = "user-guide/basics/reading-writing.rs" +name = "user-guide-getting-started-reading-writing" +path = "user-guide/getting-started/reading-writing.rs" required-features = ["polars/json"] +[[bin]] +name = "user-guide-concepts-data-structures" +path = "user-guide/concepts/data-structures.rs" [[bin]] name = "user-guide-concepts-contexts" @@ -78,8 +81,8 @@ name = "user-guide-expressions-lists" path = "user-guide/expressions/lists.rs" required-features = ["polars/lazy"] [[bin]] -name = "user-guide-expressions-null" -path = "user-guide/expressions/null.rs" +name = "user-guide-expressions-missing-data" +path = "user-guide/expressions/missing-data.rs" required-features = ["polars/lazy"] [[bin]] name = "user-guide-expressions-operators" diff --git a/docs/src/rust/user-guide/concepts/data-structures.rs b/docs/src/rust/user-guide/concepts/data-structures.rs new file mode 100644 index 000000000000..2334f7718569 --- /dev/null +++ b/docs/src/rust/user-guide/concepts/data-structures.rs @@ -0,0 +1,51 @@ +fn main() { + // --8<-- [start:series] + use polars::prelude::*; + + let s = Series::new("a", &[1, 2, 3, 4, 5]); + + println!("{}", s); + // --8<-- [end:series] + + // --8<-- [start:dataframe] + use chrono::NaiveDate; + + let df: DataFrame = df!( + "integer" => &[1, 2, 3, 4, 5], + "date" => &[ + NaiveDate::from_ymd_opt(2025, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2025, 1, 2).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2025, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2025, 1, 4).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2025, 1, 5).unwrap().and_hms_opt(0, 0, 0).unwrap(), + ], + "float" => &[4.0, 5.0, 6.0, 7.0, 8.0] + ) + .unwrap(); + + println!("{}", df); + // --8<-- [end:dataframe] + + // --8<-- [start:head] + let df_head = df.head(Some(3)); + + println!("{}", df_head); + // --8<-- [end:head] + + // --8<-- [start:tail] + let df_tail = df.tail(Some(3)); + + println!("{}", df_tail); + // --8<-- [end:tail] + + // --8<-- [start:sample] + let n = Series::new("", &[2]); + let sampled_df = df.sample_n(&n, false, false, None).unwrap(); + + println!("{}", sampled_df); + // --8<-- [end:sample] + + // --8<-- [start:describe] + // Not available in Rust + // --8<-- [end:describe] +} diff --git a/docs/src/rust/user-guide/expressions/casting.rs b/docs/src/rust/user-guide/expressions/casting.rs index 3729ca0492ca..b18ca19022df 100644 --- a/docs/src/rust/user-guide/expressions/casting.rs +++ b/docs/src/rust/user-guide/expressions/casting.rs @@ -1,5 +1,4 @@ // --8<-- [start:setup] -use polars::lazy::dsl::StrptimeOptions; use polars::prelude::*; // --8<-- [end:setup] diff --git a/docs/src/rust/user-guide/expressions/null.rs b/docs/src/rust/user-guide/expressions/missing-data.rs similarity index 100% rename from docs/src/rust/user-guide/expressions/null.rs rename to docs/src/rust/user-guide/expressions/missing-data.rs diff --git a/docs/src/rust/user-guide/expressions/structs.rs b/docs/src/rust/user-guide/expressions/structs.rs index 502f423fdf0d..01c08eaf3d7f 100644 --- a/docs/src/rust/user-guide/expressions/structs.rs +++ b/docs/src/rust/user-guide/expressions/structs.rs @@ -1,5 +1,4 @@ // --8<-- [start:setup] -use polars::lazy::dsl::len; use polars::prelude::*; // --8<-- [end:setup] fn main() -> Result<(), Box> { diff --git a/docs/src/rust/user-guide/basics/expressions.rs b/docs/src/rust/user-guide/getting-started/expressions.rs similarity index 100% rename from docs/src/rust/user-guide/basics/expressions.rs rename to docs/src/rust/user-guide/getting-started/expressions.rs diff --git a/docs/src/rust/user-guide/basics/joins.rs b/docs/src/rust/user-guide/getting-started/joins.rs similarity index 100% rename from docs/src/rust/user-guide/basics/joins.rs rename to docs/src/rust/user-guide/getting-started/joins.rs diff --git a/docs/src/rust/user-guide/basics/reading-writing.rs b/docs/src/rust/user-guide/getting-started/reading-writing.rs similarity index 96% rename from docs/src/rust/user-guide/basics/reading-writing.rs rename to docs/src/rust/user-guide/getting-started/reading-writing.rs index dad5e8713d24..bc021e9a21de 100644 --- a/docs/src/rust/user-guide/basics/reading-writing.rs +++ b/docs/src/rust/user-guide/getting-started/reading-writing.rs @@ -13,7 +13,8 @@ fn main() -> Result<(), Box> { NaiveDate::from_ymd_opt(2025, 1, 2).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2025, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), ], - "float" => &[4.0, 5.0, 6.0] + "float" => &[4.0, 5.0, 6.0], + "string" => &["a", "b", "c"], ) .unwrap(); println!("{}", df); diff --git a/docs/src/rust/user-guide/transformations/pivot.rs b/docs/src/rust/user-guide/transformations/pivot.rs index 2115b528579c..804ead13f056 100644 --- a/docs/src/rust/user-guide/transformations/pivot.rs +++ b/docs/src/rust/user-guide/transformations/pivot.rs @@ -7,20 +7,29 @@ fn main() -> Result<(), Box> { // --8<-- [start:df] let df = df!( "foo"=> ["A", "A", "B", "B", "C"], - "N"=> [1, 2, 2, 4, 2], "bar"=> ["k", "l", "m", "n", "o"], + "N"=> [1, 2, 2, 4, 2], )?; println!("{}", &df); // --8<-- [end:df] // --8<-- [start:eager] - let out = pivot(&df, ["N"], ["foo"], ["bar"], false, None, None)?; + let out = pivot(&df, ["foo"], ["bar"], Some(["N"]), false, None, None)?; println!("{}", &out); // --8<-- [end:eager] // --8<-- [start:lazy] let q = df.lazy(); - let q2 = pivot(&q.collect()?, ["N"], ["foo"], ["bar"], false, None, None)?.lazy(); + let q2 = pivot( + &q.collect()?, + ["foo"], + ["bar"], + Some(["N"]), + false, + None, + None, + )? + .lazy(); let out = q2.collect()?; println!("{}", &out); // --8<-- [end:lazy] diff --git a/docs/src/rust/user-guide/transformations/time-series/filter.rs b/docs/src/rust/user-guide/transformations/time-series/filter.rs index 56c6589c1555..06ce39eb0c5f 100644 --- a/docs/src/rust/user-guide/transformations/time-series/filter.rs +++ b/docs/src/rust/user-guide/transformations/time-series/filter.rs @@ -1,7 +1,6 @@ // --8<-- [start:setup] use chrono::prelude::*; use polars::io::prelude::*; -use polars::lazy::dsl::StrptimeOptions; use polars::prelude::*; // --8<-- [end:setup] diff --git a/docs/src/rust/user-guide/transformations/time-series/parsing.rs b/docs/src/rust/user-guide/transformations/time-series/parsing.rs index b35a52215741..3462943d15af 100644 --- a/docs/src/rust/user-guide/transformations/time-series/parsing.rs +++ b/docs/src/rust/user-guide/transformations/time-series/parsing.rs @@ -1,6 +1,5 @@ // --8<-- [start:setup] use polars::io::prelude::*; -use polars::lazy::dsl::StrptimeOptions; use polars::prelude::*; // --8<-- [end:setup] diff --git a/docs/src/rust/user-guide/transformations/time-series/rolling.rs b/docs/src/rust/user-guide/transformations/time-series/rolling.rs index fc81f34412bb..c9b7e58906cc 100644 --- a/docs/src/rust/user-guide/transformations/time-series/rolling.rs +++ b/docs/src/rust/user-guide/transformations/time-series/rolling.rs @@ -1,7 +1,6 @@ // --8<-- [start:setup] use chrono::prelude::*; use polars::io::prelude::*; -use polars::lazy::dsl::GetOutput; use polars::prelude::*; // --8<-- [end:setup] diff --git a/docs/user-guide/concepts/data-structures.md b/docs/user-guide/concepts/data-structures.md index 0cba7289ad02..860ac9da99bb 100644 --- a/docs/user-guide/concepts/data-structures.md +++ b/docs/user-guide/concepts/data-structures.md @@ -7,20 +7,20 @@ The core base data structures provided by Polars are `Series` and `DataFrame`. Series are a 1-dimensional data structure. Within a series all elements have the same [Data Type](data-types/overview.md) . The snippet below shows how to create a simple named `Series` object. -{{code_block('getting-started/series-dataframes','series',['Series'])}} +{{code_block('user-guide/concepts/data-structures','series',['Series'])}} -```python exec="on" result="text" session="getting-started/series" ---8<-- "python/user-guide/basics/series-dataframes.py:series" +```python exec="on" result="text" session="user-guide/data-structures" +--8<-- "python/user-guide/concepts/data-structures.py:series" ``` ## DataFrame A `DataFrame` is a 2-dimensional data structure that is backed by a `Series`, and it can be seen as an abstraction of a collection (e.g. list) of `Series`. Operations that can be executed on a `DataFrame` are very similar to what is done in a `SQL` like query. You can `GROUP BY`, `JOIN`, `PIVOT`, but also define custom functions. -{{code_block('getting-started/series-dataframes','dataframe',['DataFrame'])}} +{{code_block('user-guide/concepts/data-structures','dataframe',['DataFrame'])}} -```python exec="on" result="text" session="getting-started/series" ---8<-- "python/user-guide/basics/series-dataframes.py:dataframe" +```python exec="on" result="text" session="user-guide/data-structures" +--8<-- "python/user-guide/concepts/data-structures.py:dataframe" ``` ### Viewing data @@ -31,38 +31,38 @@ This part focuses on viewing data in a `DataFrame`. We will use the `DataFrame` The `head` function shows by default the first 5 rows of a `DataFrame`. You can specify the number of rows you want to see (e.g. `df.head(10)`). -{{code_block('getting-started/series-dataframes','head',['head'])}} +{{code_block('user-guide/concepts/data-structures','head',['head'])}} -```python exec="on" result="text" session="getting-started/series" ---8<-- "python/user-guide/basics/series-dataframes.py:head" +```python exec="on" result="text" session="user-guide/data-structures" +--8<-- "python/user-guide/concepts/data-structures.py:head" ``` #### Tail The `tail` function shows the last 5 rows of a `DataFrame`. You can also specify the number of rows you want to see, similar to `head`. -{{code_block('getting-started/series-dataframes','tail',['tail'])}} +{{code_block('user-guide/concepts/data-structures','tail',['tail'])}} -```python exec="on" result="text" session="getting-started/series" ---8<-- "python/user-guide/basics/series-dataframes.py:tail" +```python exec="on" result="text" session="user-guide/data-structures" +--8<-- "python/user-guide/concepts/data-structures.py:tail" ``` #### Sample If you want to get an impression of the data of your `DataFrame`, you can also use `sample`. With `sample` you get an _n_ number of random rows from the `DataFrame`. -{{code_block('getting-started/series-dataframes','sample',['sample'])}} +{{code_block('user-guide/concepts/data-structures','sample',['sample'])}} -```python exec="on" result="text" session="getting-started/series" ---8<-- "python/user-guide/basics/series-dataframes.py:sample" +```python exec="on" result="text" session="user-guide/data-structures" +--8<-- "python/user-guide/concepts/data-structures.py:sample" ``` #### Describe `Describe` returns summary statistics of your `DataFrame`. It will provide several quick statistics if possible. -{{code_block('getting-started/series-dataframes','describe',['describe'])}} +{{code_block('user-guide/concepts/data-structures','describe',['describe'])}} -```python exec="on" result="text" session="getting-started/series" ---8<-- "python/user-guide/basics/series-dataframes.py:describe" +```python exec="on" result="text" session="user-guide/data-structures" +--8<-- "python/user-guide/concepts/data-structures.py:describe" ``` diff --git a/docs/user-guide/concepts/data-types/overview.md b/docs/user-guide/concepts/data-types/overview.md index 30e7073bccc5..86c705605031 100644 --- a/docs/user-guide/concepts/data-types/overview.md +++ b/docs/user-guide/concepts/data-types/overview.md @@ -16,7 +16,7 @@ from Arrow, with the exception of `String` (this is actually `LargeUtf8`), `Cate | | `UInt64` | 64-bit unsigned integer. | | | `Float32` | 32-bit floating point. | | | `Float64` | 64-bit floating point. | -| Nested | `Struct` | A struct array is represented as a `Vec` and is useful to pack multiple/heterogenous values in a single column. | +| Nested | `Struct` | A struct array is represented as a `Vec` and is useful to pack multiple/heterogeneous values in a single column. | | | `List` | A list array contains a child array containing the list values and an offset array. (this is actually Arrow `LargeList` internally). | | Temporal | `Date` | Date representation, internally represented as days since UNIX epoch encoded by a 32-bit signed integer. | | | `Datetime` | Datetime representation, internally represented as microseconds since UNIX epoch encoded by a 64-bit signed integer. | @@ -41,6 +41,6 @@ Polars generally follows the IEEE 754 floating point standard for `Float32` and e.g. a sort or group by operation may canonicalize all zeroes to +0 and all NaNs to a positive NaN without payload for efficient equality checks. -Polars always attempts to provide reasonably accurate results for floating point computations, but does not provide guarantees +Polars always attempts to provide reasonably accurate results for floating point computations but does not provide guarantees on the error unless mentioned otherwise. Generally speaking 100% accurate results are infeasibly expensive to acquire (requiring much larger internal representations than 64-bit floats), and thus some error is always to be expected. diff --git a/docs/user-guide/concepts/lazy-vs-eager.md b/docs/user-guide/concepts/lazy-vs-eager.md index 3e2c54c2e39f..4822f81a5d1d 100644 --- a/docs/user-guide/concepts/lazy-vs-eager.md +++ b/docs/user-guide/concepts/lazy-vs-eager.md @@ -1,6 +1,6 @@ # Lazy / eager API -Polars supports two modes of operation: lazy and eager. In the eager API the query is executed immediately while in the lazy API the query is only evaluated once it is 'needed'. Deferring the execution to the last minute can have significant performance advantages that is why the Lazy API is preferred in most cases. Let us demonstrate this with an example: +Polars supports two modes of operation: lazy and eager. In the eager API the query is executed immediately while in the lazy API the query is only evaluated once it is 'needed'. Deferring the execution to the last minute can have significant performance advantages and is why the Lazy API is preferred in most cases. Let us demonstrate this with an example: {{code_block('user-guide/concepts/lazy-vs-eager','eager',['read_csv'])}} diff --git a/docs/user-guide/ecosystem.md b/docs/user-guide/ecosystem.md index f745cdefa41f..31fb44595e37 100644 --- a/docs/user-guide/ecosystem.md +++ b/docs/user-guide/ecosystem.md @@ -1,4 +1,4 @@ -# Integrations and ecosystem +# Ecosystem ## Introduction @@ -12,27 +12,29 @@ On this page you can find a non-exhaustive list of libraries and tools that supp - [Machine learning](#machine-learning) - [Other](#other) -#### [Apache Arrow](https://arrow.apache.org/) +--- -Apache Arrow enables zero-copy reads of data within the same process, meaning that data can be directly accessed in its in-memory format without the need for copying or serialisation. This enhances performance when integrating with different tools using Apache Arrow. Polars is compatible with a wide range of libraries that also make use of Apache Arrow, like Pandas and DuckDB. +### Apache Arrow + +[Apache Arrow](https://arrow.apache.org/) enables zero-copy reads of data within the same process, meaning that data can be directly accessed in its in-memory format without the need for copying or serialisation. This enhances performance when integrating with different tools using Apache Arrow. Polars is compatible with a wide range of libraries that also make use of Apache Arrow, like Pandas and DuckDB. ### Data visualisation -#### [hvPlot](https://hvplot.holoviz.org/) +#### hvPlot -hvPlot is available as the default plotting backend for Polars making it simple to create interactive and static visualisations. You can use hvPlot by using the feature flag `plot` during installing. +[hvPlot](https://hvplot.holoviz.org/) is available as the default plotting backend for Polars making it simple to create interactive and static visualisations. You can use hvPlot by using the feature flag `plot` during installing. ```python pip install 'polars[plot]' ``` -#### [Matplotlib](https://matplotlib.org/) +#### Matplotlib -Matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python. Matplotlib makes easy things easy and hard things possible. +[Matplotlib](https://matplotlib.org/) is a comprehensive library for creating static, animated, and interactive visualizations in Python. Matplotlib makes easy things easy and hard things possible. -#### [Plotly Dash](https://github.com/plotly/dash) +#### Plotly -Dash is the original low-code framework for rapidly building data apps in Python. Learn more about how to build fast Dash apps at [Plotly.com](https://plotly.com/blog/polars-to-build-fast-dash-apps-for-large-datasets/). +[Plotly](https://plotly.com/python/) is an interactive, open-source, and browser-based graphing library for Python. Built on top of plotly.js, it ships with over 30 chart types, including scientific charts, 3D graphs, statistical charts, SVG maps, financial charts, and more. #### [Seaborn](https://seaborn.pydata.org/) @@ -40,28 +42,32 @@ Seaborn is a Python data visualization library based on Matplotlib. It provides ### IO -#### [Delta Lake](https://github.com/delta-io/delta-rs) +#### Delta Lake -The Delta Lake project aims to unlock the power of the Deltalake for as many users and projects as possible by providing native low-level APIs aimed at developers and integrators, as well as a high-level operations API that lets you query, inspect, and operate your Delta Lake with ease. +The [Delta Lake](https://github.com/delta-io/delta-rs) project aims to unlock the power of the Deltalake for as many users and projects as possible by providing native low-level APIs aimed at developers and integrators, as well as a high-level operations API that lets you query, inspect, and operate your Delta Lake with ease. Read how to use Delta Lake with Polars [at Delta Lake](https://delta-io.github.io/delta-rs/integrations/delta-lake-polars/#reading-a-delta-lake-table-with-polars). ### Machine Learning -#### [Scikit Learn](https://scikit-learn.org/stable/) +#### Scikit Learn -Since Scikit Learn 1.4, all transformers support Polars output. See the change log for [more details](https://scikit-learn.org/dev/whats_new/v1.4.html#changes-impacting-all-modules). +Since [Scikit Learn](https://scikit-learn.org/stable/) 1.4, all transformers support Polars output. See the change log for [more details](https://scikit-learn.org/dev/whats_new/v1.4.html#changes-impacting-all-modules). ### Other -#### [Great Tables](https://posit-dev.github.io/great-tables/articles/intro.html) +#### DuckDB + +[DuckDB](https://duckdb.org) is a high-performance analytical database system. It is designed to be fast, reliable, portable, and easy to use. DuckDB provides a rich SQL dialect, with support far beyond basic SQL. DuckDB supports arbitrary and nested correlated subqueries, window functions, collations, complex types (arrays, structs), and more. Read about integration with Polars [on the DuckDB website](https://duckdb.org/docs/guides/python/polars). + +#### Great Tables -With Great Tables anyone can make wonderful-looking tables in Python. Here is a [blog post](https://posit-dev.github.io/great-tables/blog/polars-styling/) on how to use Great Tables with Polars. +With [Great Tables](https://posit-dev.github.io/great-tables/articles/intro.html) anyone can make wonderful-looking tables in Python. Here is a [blog post](https://posit-dev.github.io/great-tables/blog/polars-styling/) on how to use Great Tables with Polars. -#### [LanceDB](https://lancedb.com/) +#### LanceDB -LanceDB is a developer-friendly, serverless vector database for AI applications. They have added a direct integration with Polars. LanceDB can ingest Polars dataframes, return results as polars dataframes, and export the entire table as a polars lazyframe. You can find a quick tutorial in their blog [LanceDB + Polars](https://blog.lancedb.com/lancedb-polars-2d5eb32a8aa3) +[LanceDB](https://lancedb.com/) is a developer-friendly, serverless vector database for AI applications. They have added a direct integration with Polars. LanceDB can ingest Polars dataframes, return results as polars dataframes, and export the entire table as a polars lazyframe. You can find a quick tutorial in their blog [LanceDB + Polars](https://blog.lancedb.com/lancedb-polars-2d5eb32a8aa3) -#### [Mage](https://www.mage.ai) +#### Mage -Open-source data pipeline tool for transforming and integrating data. Learn about integration between Polars and Mage at [docs.mage.ai](https://docs.mage.ai/integrations/polars). +[Mage](https://www.mage.ai) is an open-source data pipeline tool for transforming and integrating data. Learn about integration between Polars and Mage at [docs.mage.ai](https://docs.mage.ai/integrations/polars). diff --git a/docs/user-guide/expressions/index.md b/docs/user-guide/expressions/index.md index 3724e09ce15e..32550974782e 100644 --- a/docs/user-guide/expressions/index.md +++ b/docs/user-guide/expressions/index.md @@ -8,7 +8,7 @@ In the `Contexts` sections we outlined what `Expressions` are and how they are i - [Casting](casting.md) - [Strings](strings.md) - [Aggregation](aggregation.md) -- [Null](null.md) +- [Missing data](missing-data.md) - [Window](window.md) - [Folds](folds.md) - [Lists](lists.md) diff --git a/docs/user-guide/expressions/null.md b/docs/user-guide/expressions/missing-data.md similarity index 68% rename from docs/user-guide/expressions/null.md rename to docs/user-guide/expressions/missing-data.md index 8092a7187cdd..8b95efabe847 100644 --- a/docs/user-guide/expressions/null.md +++ b/docs/user-guide/expressions/missing-data.md @@ -10,11 +10,11 @@ Polars also allows `NotaNumber` or `NaN` values for float columns. These `NaN` v You can manually define a missing value with the python `None` value: -{{code_block('user-guide/expressions/null','dataframe',['DataFrame'])}} +{{code_block('user-guide/expressions/missing-data','dataframe',['DataFrame'])}} -```python exec="on" result="text" session="user-guide/null" ---8<-- "python/user-guide/expressions/null.py:setup" ---8<-- "python/user-guide/expressions/null.py:dataframe" +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:setup" +--8<-- "python/user-guide/expressions/missing-data.py:dataframe" ``` !!! info @@ -27,10 +27,10 @@ Each Arrow array used by Polars stores two kinds of metadata related to missing The first piece of metadata is the `null_count` - this is the number of rows with `null` values in the column: -{{code_block('user-guide/expressions/null','count',['null_count'])}} +{{code_block('user-guide/expressions/missing-data','count',['null_count'])}} -```python exec="on" result="text" session="user-guide/null" ---8<-- "python/user-guide/expressions/null.py:count" +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:count" ``` The `null_count` method can be called on a `DataFrame`, a column from a `DataFrame` or a `Series`. The `null_count` method is a cheap operation as `null_count` is already calculated for the underlying Arrow array. @@ -40,10 +40,10 @@ The validity bitmap is memory efficient as it is bit encoded - each value is eit You can return a `Series` based on the validity bitmap for a column in a `DataFrame` or a `Series` with the `is_null` method: -{{code_block('user-guide/expressions/null','isnull',['is_null'])}} +{{code_block('user-guide/expressions/missing-data','isnull',['is_null'])}} -```python exec="on" result="text" session="user-guide/null" ---8<-- "python/user-guide/expressions/null.py:isnull" +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:isnull" ``` The `is_null` method is a cheap operation that does not require scanning the full column for `null` values. This is because the validity bitmap already exists and can be returned as a Boolean array. @@ -59,30 +59,30 @@ Missing data in a `Series` can be filled with the `fill_null` method. You have t We illustrate each way to fill nulls by defining a simple `DataFrame` with a missing value in `col2`: -{{code_block('user-guide/expressions/null','dataframe2',['DataFrame'])}} +{{code_block('user-guide/expressions/missing-data','dataframe2',['DataFrame'])}} -```python exec="on" result="text" session="user-guide/null" ---8<-- "python/user-guide/expressions/null.py:dataframe2" +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:dataframe2" ``` ### Fill with specified literal value We can fill the missing data with a specified literal value with `pl.lit`: -{{code_block('user-guide/expressions/null','fill',['fill_null'])}} +{{code_block('user-guide/expressions/missing-data','fill',['fill_null'])}} -```python exec="on" result="text" session="user-guide/null" ---8<-- "python/user-guide/expressions/null.py:fill" +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:fill" ``` ### Fill with a strategy We can fill the missing data with a strategy such as filling forward: -{{code_block('user-guide/expressions/null','fillstrategy',['fill_null'])}} +{{code_block('user-guide/expressions/missing-data','fillstrategy',['fill_null'])}} -```python exec="on" result="text" session="user-guide/null" ---8<-- "python/user-guide/expressions/null.py:fillstrategy" +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:fillstrategy" ``` You can find other fill strategies in the API docs. @@ -92,10 +92,10 @@ You can find other fill strategies in the API docs. For more flexibility we can fill the missing data with an expression. For example, to fill nulls with the median value from that column: -{{code_block('user-guide/expressions/null','fillexpr',['fill_null'])}} +{{code_block('user-guide/expressions/missing-data','fillexpr',['fill_null'])}} -```python exec="on" result="text" session="user-guide/null" ---8<-- "python/user-guide/expressions/null.py:fillexpr" +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:fillexpr" ``` In this case the column is cast from integer to float because the median is a float statistic. @@ -104,20 +104,20 @@ In this case the column is cast from integer to float because the median is a fl In addition, we can fill nulls with interpolation (without using the `fill_null` function): -{{code_block('user-guide/expressions/null','fillinterpolate',['interpolate'])}} +{{code_block('user-guide/expressions/missing-data','fillinterpolate',['interpolate'])}} -```python exec="on" result="text" session="user-guide/null" ---8<-- "python/user-guide/expressions/null.py:fillinterpolate" +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:fillinterpolate" ``` ## `NotaNumber` or `NaN` values Missing data in a `Series` has a `null` value. However, you can use `NotaNumber` or `NaN` values in columns with float datatypes. These `NaN` values can be created from Numpy's `np.nan` or the native python `float('nan')`: -{{code_block('user-guide/expressions/null','nan',['DataFrame'])}} +{{code_block('user-guide/expressions/missing-data','nan',['DataFrame'])}} -```python exec="on" result="text" session="user-guide/null" ---8<-- "python/user-guide/expressions/null.py:nan" +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:nan" ``` !!! info @@ -133,8 +133,8 @@ Polars has `is_nan` and `fill_nan` methods which work in a similar way to the `i One further difference between `null` and `NaN` values is that taking the `mean` of a column with `null` values excludes the `null` values from the calculation but with `NaN` values taking the mean results in a `NaN`. This behaviour can be avoided by replacing the `NaN` values with `null` values; -{{code_block('user-guide/expressions/null','nanfill',['fill_nan'])}} +{{code_block('user-guide/expressions/missing-data','nanfill',['fill_nan'])}} -```python exec="on" result="text" session="user-guide/null" ---8<-- "python/user-guide/expressions/null.py:nanfill" +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:nanfill" ``` diff --git a/docs/user-guide/expressions/plugins.md b/docs/user-guide/expressions/plugins.md index 8bfab6691c1e..c427ace6dac6 100644 --- a/docs/user-guide/expressions/plugins.md +++ b/docs/user-guide/expressions/plugins.md @@ -37,7 +37,7 @@ crate-type = ["cdylib"] [dependencies] polars = { version = "*" } -pyo3 = { version = "*", features = ["extension-module"] } +pyo3 = { version = "*", features = ["extension-module", "abi-py38"] } pyo3-polars = { version = "*", features = ["derive"] } serde = { version = "*", features = ["derive"] } ``` @@ -96,20 +96,44 @@ import polars as pl from polars.type_aliases import IntoExpr from polars.utils.udfs import _get_shared_lib_location +from expression_lib.utils import parse_into_expr + # Boilerplate needed to inform Polars of the location of binary wheel. lib = _get_shared_lib_location(__file__) -@pl.api.register_expr_namespace("language") -class Language: - def __init__(self, expr: pl.Expr): - self._expr = expr - - def pig_latinnify(self) -> pl.Expr: - return self._expr._register_plugin( - lib=lib, - symbol="pig_latinnify", - is_elementwise=True, - ) +def pig_latinnify(expr: IntoExpr, capitalize: bool = False) -> pl.Expr: + expr = parse_into_expr(expr) + return expr.register_plugin( + lib=lib, + symbol="pig_latinnify", + is_elementwise=True, + ) +``` + +```python +# expression_lib/utils.py +import polars as pl + +from polars.type_aliases import IntoExpr, PolarsDataType + + +def parse_into_expr( + expr: IntoExpr, + *, + str_as_lit: bool = False, + list_as_lit: bool = True, + dtype: PolarsDataType | None = None, +) -> pl.Expr: + """Parse a single input into an expression.""" + if isinstance(expr, pl.Expr): + pass + elif isinstance(expr, str) and not str_as_lit: + expr = pl.col(expr) + elif isinstance(expr, list) and not list_as_lit: + expr = pl.lit(pl.Series(expr), dtype=dtype) + else: + expr = pl.lit(expr, dtype=dtype) + return expr ``` We can then compile this library in our environment by installing `maturin` and running `maturin develop --release`. @@ -118,15 +142,19 @@ And that's it. Our expression is ready to use! ```python import polars as pl -from expression_lib import Language +from expression_lib import pig_latinnify df = pl.DataFrame( { "convert": ["pig", "latin", "is", "silly"], } ) +out = df.with_columns(pig_latin=pig_latinnify("convert")) +``` +Alternatively, you can [register a custom namespace](https://docs.pola.rs/py-polars/html/reference/api/polars.api.register_expr_namespace.html#polars.api.register_expr_namespace), which enables you to write: +```python out = df.with_columns( pig_latin=pl.col("convert").language.pig_latinnify(), ) @@ -173,38 +201,34 @@ fn append_kwargs(input: &[Series], kwargs: MyKwargs) -> PolarsResult { On the Python side the kwargs can be passed when we register the plugin. ```python -@pl.api.register_expr_namespace("my_expr") -class MyCustomExpr: - def __init__(self, expr: pl.Expr): - self._expr = expr - - def append_args( - self, - float_arg: float, - integer_arg: int, - string_arg: str, - boolean_arg: bool, - ) -> pl.Expr: - """ - This example shows how arguments other than `Series` can be used. - """ - return self._expr._register_plugin( - lib=lib, - args=[], - kwargs={ - "float_arg": float_arg, - "integer_arg": integer_arg, - "string_arg": string_arg, - "boolean_arg": boolean_arg, - }, - symbol="append_kwargs", - is_elementwise=True, - ) +def append_args( + expr: IntoExpr, + float_arg: float, + integer_arg: int, + string_arg: str, + boolean_arg: bool, +) -> pl.Expr: + """ + This example shows how arguments other than `Series` can be used. + """ + expr = parse_into_expr(expr) + return expr.register_plugin( + lib=lib, + args=[], + kwargs={ + "float_arg": float_arg, + "integer_arg": integer_arg, + "string_arg": string_arg, + "boolean_arg": boolean_arg, + }, + symbol="append_kwargs", + is_elementwise=True, + ) ``` ## Output data types -Output data types ofcourse don't have to be fixed. They often depend on the input types of an expression. To accommodate +Output data types of course don't have to be fixed. They often depend on the input types of an expression. To accommodate this you can provide the `#[polars_expr()]` macro with an `output_type_func` argument that points to a function. This function can map input fields `&[Field]` to an output `Field` (name and data type). @@ -253,3 +277,11 @@ Here is a curated (non-exhaustive) list of community implemented plugins. - [polars-distance](https://github.com/ion-elgreco/polars-distance) Polars plugin for pairwise distance functions - [polars-ds](https://github.com/abstractqqq/polars_ds_extension) Polars extension aiming to simplify common numerical/string data analysis procedures - [polars-hash](https://github.com/ion-elgreco/polars-hash) Stable non-cryptographic and cryptographic hashing functions for Polars +- [polars-reverse-geocode](https://github.com/MarcoGorelli/polars-reverse-geocode) Offline reverse geocoder for finding the closest city + to a given (latitude, longitude) pair + +## Other community material + +- [Polars plugins tutorial](https://marcogorelli.github.io/polars-plugins-tutorial/) Learn how to write a plugin by + going through some very simple and minimal examples +- [cookiecutter-polars-plugin](https://github.com/MarcoGorelli/cookiecutter-polars-plugins) Project template for Polars Plugins diff --git a/docs/user-guide/expressions/structs.md b/docs/user-guide/expressions/structs.md index 61978bbc25e7..056c1b2e21b7 100644 --- a/docs/user-guide/expressions/structs.md +++ b/docs/user-guide/expressions/structs.md @@ -31,7 +31,7 @@ Quite unexpected an output, especially if coming from tools that do not have suc !!! note "Why `value_counts` returns a `Struct`" - Polars expressions always have a `Fn(Series) -> Series` signature and `Struct` is thus the data type that allows us to provide multiple columns as input/ouput of an expression. In other words, all expressions have to return a `Series` object, and `Struct` allows us to stay consistent with that requirement. + Polars expressions always have a `Fn(Series) -> Series` signature and `Struct` is thus the data type that allows us to provide multiple columns as input/output of an expression. In other words, all expressions have to return a `Series` object, and `Struct` allows us to stay consistent with that requirement. ## Structs as `dict`s diff --git a/docs/user-guide/getting-started.md b/docs/user-guide/getting-started.md index 13d732241c21..2a601597bb3d 100644 --- a/docs/user-guide/getting-started.md +++ b/docs/user-guide/getting-started.md @@ -24,21 +24,21 @@ This chapter is here to help you get started with Polars. It covers all the fund Polars supports reading and writing for common file formats (e.g. csv, json, parquet), cloud storage (S3, Azure Blob, BigQuery) and databases (e.g. postgres, mysql). Below we show the concept of reading and writing to disk. -{{code_block('user-guide/basics/reading-writing','dataframe',['DataFrame'])}} +{{code_block('user-guide/getting-started/reading-writing','dataframe',['DataFrame'])}} ```python exec="on" result="text" session="getting-started/reading" ---8<-- "python/user-guide/basics/reading-writing.py:dataframe" +--8<-- "python/user-guide/getting-started/reading-writing.py:dataframe" ``` In the example below we write the DataFrame to a csv file called `output.csv`. After that, we read it back using `read_csv` and then `print` the result for inspection. -{{code_block('user-guide/basics/reading-writing','csv',['read_csv','write_csv'])}} +{{code_block('user-guide/getting-started/reading-writing','csv',['read_csv','write_csv'])}} ```python exec="on" result="text" session="getting-started/reading" ---8<-- "python/user-guide/basics/reading-writing.py:csv" +--8<-- "python/user-guide/getting-started/reading-writing.py:csv" ``` -For more examples on the CSV file format and other data formats, start with the [IO section](io/index.md) of the User Guide. +For more examples on the CSV file format and other data formats, start with the [IO section](io/index.md) of the user guide. ## Expressions @@ -49,7 +49,7 @@ For more examples on the CSV file format and other data formats, start with the - `with_columns` - `group_by` -To learn more about expressions and the context in which they operate, see the User Guide sections: [Contexts](concepts/contexts.md) and [Expressions](concepts/expressions.md). +To learn more about expressions and the context in which they operate, see the user guide sections: [Contexts](concepts/contexts.md) and [Expressions](concepts/expressions.md). ### Select @@ -60,46 +60,46 @@ To select a column we need to do two things: In the example below you see that we select `col('*')`. The asterisk stands for all columns. -{{code_block('user-guide/basics/expressions','select',['select'])}} +{{code_block('user-guide/getting-started/expressions','select',['select'])}} ```python exec="on" result="text" session="getting-started/expressions" ---8<-- "python/user-guide/basics/expressions.py:setup" +--8<-- "python/user-guide/getting-started/expressions.py:setup" print( - --8<-- "python/user-guide/basics/expressions.py:select" + --8<-- "python/user-guide/getting-started/expressions.py:select" ) ``` You can also specify the specific columns that you want to return. There are two ways to do this. The first option is to pass the column names, as seen below. -{{code_block('user-guide/basics/expressions','select2',['select'])}} +{{code_block('user-guide/getting-started/expressions','select2',['select'])}} ```python exec="on" result="text" session="getting-started/expressions" print( - --8<-- "python/user-guide/basics/expressions.py:select2" + --8<-- "python/user-guide/getting-started/expressions.py:select2" ) ``` -Follow these links to other parts of the User guide to learn more about [basic operations](expressions/operators.md) or [column selections](expressions/column-selections.md). +Follow these links to other parts of the user guide to learn more about [basic operations](expressions/operators.md) or [column selections](expressions/column-selections.md). ### Filter The `filter` option allows us to create a subset of the `DataFrame`. We use the same `DataFrame` as earlier and we filter between two specified dates. -{{code_block('user-guide/basics/expressions','filter',['filter'])}} +{{code_block('user-guide/getting-started/expressions','filter',['filter'])}} ```python exec="on" result="text" session="getting-started/expressions" print( - --8<-- "python/user-guide/basics/expressions.py:filter" + --8<-- "python/user-guide/getting-started/expressions.py:filter" ) ``` With `filter` you can also create more complex filters that include multiple columns. -{{code_block('user-guide/basics/expressions','filter2',['filter'])}} +{{code_block('user-guide/getting-started/expressions','filter2',['filter'])}} ```python exec="on" result="text" session="getting-started/expressions" print( - --8<-- "python/user-guide/basics/expressions.py:filter2" + --8<-- "python/user-guide/getting-started/expressions.py:filter2" ) ``` @@ -107,38 +107,38 @@ print( `with_columns` allows you to create new columns for your analyses. We create two new columns `e` and `b+42`. First we sum all values from column `b` and store the results in column `e`. After that we add `42` to the values of `b`. Creating a new column `b+42` to store these results. -{{code_block('user-guide/basics/expressions','with_columns',['with_columns'])}} +{{code_block('user-guide/getting-started/expressions','with_columns',['with_columns'])}} ```python exec="on" result="text" session="getting-started/expressions" print( - --8<-- "python/user-guide/basics/expressions.py:with_columns" + --8<-- "python/user-guide/getting-started/expressions.py:with_columns" ) ``` -### Group_by +### Group by We will create a new `DataFrame` for the Group by functionality. This new `DataFrame` will include several 'groups' that we want to group by. -{{code_block('user-guide/basics/expressions','dataframe2',['DataFrame'])}} +{{code_block('user-guide/getting-started/expressions','dataframe2',['DataFrame'])}} ```python exec="on" result="text" session="getting-started/expressions" ---8<-- "python/user-guide/basics/expressions.py:dataframe2" +--8<-- "python/user-guide/getting-started/expressions.py:dataframe2" print(df2) ``` -{{code_block('user-guide/basics/expressions','group_by',['group_by'])}} +{{code_block('user-guide/getting-started/expressions','group_by',['group_by'])}} ```python exec="on" result="text" session="getting-started/expressions" print( - --8<-- "python/user-guide/basics/expressions.py:group_by" + --8<-- "python/user-guide/getting-started/expressions.py:group_by" ) ``` -{{code_block('user-guide/basics/expressions','group_by2',['group_by'])}} +{{code_block('user-guide/getting-started/expressions','group_by2',['group_by'])}} ```python exec="on" result="text" session="getting-started/expressions" print( - --8<-- "python/user-guide/basics/expressions.py:group_by2" + --8<-- "python/user-guide/getting-started/expressions.py:group_by2" ) ``` @@ -146,16 +146,16 @@ print( Below are some examples on how to combine operations to create the `DataFrame` you require. -{{code_block('user-guide/basics/expressions','combine',['select','with_columns'])}} +{{code_block('user-guide/getting-started/expressions','combine',['select','with_columns'])}} ```python exec="on" result="text" session="getting-started/expressions" ---8<-- "python/user-guide/basics/expressions.py:combine" +--8<-- "python/user-guide/getting-started/expressions.py:combine" ``` -{{code_block('user-guide/basics/expressions','combine2',['select','with_columns'])}} +{{code_block('user-guide/getting-started/expressions','combine2',['select','with_columns'])}} ```python exec="on" result="text" session="getting-started/expressions" ---8<-- "python/user-guide/basics/expressions.py:combine2" +--8<-- "python/user-guide/getting-started/expressions.py:combine2" ``` ## Combining DataFrames @@ -166,11 +166,11 @@ There are two ways `DataFrame`s can be combined depending on the use case: join Polars supports all types of join (e.g. left, right, inner, outer). Let's have a closer look on how to `join` two `DataFrames` into a single `DataFrame`. Our two `DataFrames` both have an 'id'-like column: `a` and `x`. We can use those columns to `join` the `DataFrames` in this example. -{{code_block('user-guide/basics/joins','join',['join'])}} +{{code_block('user-guide/getting-started/joins','join',['join'])}} ```python exec="on" result="text" session="getting-started/joins" ---8<-- "python/user-guide/basics/joins.py:setup" ---8<-- "python/user-guide/basics/joins.py:join" +--8<-- "python/user-guide/getting-started/joins.py:setup" +--8<-- "python/user-guide/getting-started/joins.py:join" ``` To see more examples with other types of joins, see the [Transformations section](transformations/joins.md) in the user guide. @@ -179,8 +179,8 @@ To see more examples with other types of joins, see the [Transformations section We can also `concatenate` two `DataFrames`. Vertical concatenation will make the `DataFrame` longer. Horizontal concatenation will make the `DataFrame` wider. Below you can see the result of an horizontal concatenation of our two `DataFrames`. -{{code_block('user-guide/basics/joins','hstack',['hstack'])}} +{{code_block('user-guide/getting-started/joins','hstack',['hstack'])}} ```python exec="on" result="text" session="getting-started/joins" ---8<-- "python/user-guide/basics/joins.py:hstack" +--8<-- "python/user-guide/getting-started/joins.py:hstack" ``` diff --git a/docs/user-guide/installation.md b/docs/user-guide/installation.md index fa777272febf..83aa684bf92e 100644 --- a/docs/user-guide/installation.md +++ b/docs/user-guide/installation.md @@ -123,12 +123,11 @@ The opt-in features are: - `rows` - Create `DataFrame` from rows and extract rows from `DataFrames`. And activates `pivot` and `transpose` operations - `join_asof` - Join ASOF, to join on nearest keys instead of exact equality match. - - `cross_join` - Create the cartesian product of two DataFrames. + - `cross_join` - Create the Cartesian product of two DataFrames. - `semi_anti_join` - SEMI and ANTI joins. - `group_by_list` - Allow group by operation on keys of type List. - `row_hash` - Utility to hash DataFrame rows to UInt64Chunked - `diagonal_concat` - Concat diagonally thereby combining different schemas. - - `horizontal_concat` - Concat horizontally and extend with null values if lengths don't match - `dataframe_arithmetic` - Arithmetic on (Dataframe and DataFrames) and (DataFrame on Series) - `partition_by` - Split into multiple DataFrames partitioned by groups. - `Series`/`Expression` operations: diff --git a/docs/user-guide/lazy/schemas.md b/docs/user-guide/lazy/schemas.md index 77d2be54b722..6bb6706e86e5 100644 --- a/docs/user-guide/lazy/schemas.md +++ b/docs/user-guide/lazy/schemas.md @@ -17,11 +17,16 @@ One advantage of the lazy API is that Polars will check the schema before any da We see how this works in the following simple example where we call the `.round` expression on the integer `bar` column. -{{code_block('user-guide/lazy/schema','typecheck',['lazy','with_columns'])}} +{{code_block('user-guide/lazy/schema','lazyround',['lazy','with_columns'])}} The `.round` expression is only valid for columns with a floating point dtype. Calling `.round` on an integer column means the operation will raise an `InvalidOperationError` when we evaluate the query with `collect`. This schema check happens before the data is processed when we call `collect`. -`python exec="on" result="text" session="user-guide/lazy/schemas"` +{{code_block('user-guide/lazy/schema','typecheck',[])}} + +```python exec="on" result="text" session="user-guide/lazy/schemas" +--8<-- "python/user-guide/lazy/schema.py:lazyround" +--8<-- "python/user-guide/lazy/schema.py:typecheck" +``` If we executed this query in eager mode the error would only be found once the data had been processed in all earlier steps. diff --git a/docs/user-guide/migration/pandas.md b/docs/user-guide/migration/pandas.md index dc57354c43ab..164cfd389176 100644 --- a/docs/user-guide/migration/pandas.md +++ b/docs/user-guide/migration/pandas.md @@ -314,7 +314,7 @@ For float columns Polars permits the use of `NaN` values. These `NaN` values are In pandas an integer column with missing values is cast to be a float column with `NaN` values for the missing values (unless using optional nullable integer dtypes). In Polars any missing values in an integer column are simply `null` values and the column remains an integer column. -See the [missing data](../expressions/null.md) section for more details. +See the [missing data](../expressions/missing-data.md) section for more details. ## Pipe littering @@ -344,7 +344,7 @@ def add_ham(df: pd.DataFrame) -> pd.DataFrame: ) ``` -If we do this in polars, we would create 3 `with_column` contexts, that forces Polars to run the 3 pipes sequentially, +If we do this in polars, we would create 3 `with_columns` contexts, that forces Polars to run the 3 pipes sequentially, utilizing zero parallelism. The way to get similar abstractions in polars is creating functions that create expressions. @@ -368,7 +368,7 @@ df.with_columns( ) ``` -If you need the schema in the functions that generate the expressions, you an utilize a single `pipe`: +If you need the schema in the functions that generate the expressions, you can utilize a single `pipe`: ```python from collections import OrderedDict diff --git a/docs/user-guide/misc/alternatives.md b/docs/user-guide/misc/alternatives.md deleted file mode 100644 index 8a301ff4fcaa..000000000000 --- a/docs/user-guide/misc/alternatives.md +++ /dev/null @@ -1,66 +0,0 @@ -# Alternatives - -These are some tools that share similar functionality to what Polars does. - -- Pandas - - A very versatile tool for small data. Read [10 things I hate about pandas](https://wesmckinney.com/blog/apache-arrow-pandas-internals/) - written by the author himself. Polars has solved all those 10 things. - Polars is a versatile tool for small and large data with a more predictable, less ambiguous, and stricter API. - -- Pandas the API - - The API of pandas was designed for in memory data. This makes it a poor fit for performant analysis on large data - (read anything that does not fit into RAM). Any tool that tries to distribute that API will likely have a - suboptimal query plan compared to plans that follow from a declarative API like SQL or Polars' API. - -- Dask - - Parallelizes existing single-threaded libraries like NumPy and pandas. As a consumer of those libraries Dask - therefore has less control over low level performance and semantics. - Those libraries are treated like a black box. - On a single machine the parallelization effort can also be seriously stalled by pandas strings. - Pandas strings, by default, are stored as Python objects in - numpy arrays meaning that any operation on them is GIL bound and therefore single threaded. This can be circumvented - by multi-processing but has a non-trivial cost. - -- Modin - - Similar to Dask - -- Vaex - - Vaexs method of out-of-core analysis is memory mapping files. This works until it doesn't. For instance parquet - or csv files first need to be read and converted to a file format that can be memory mapped. Another downside is - that the OS determines when pages will be swapped. Operations that need a full data shuffle, such as - sorts, have terrible performance on memory mapped data. - Polars' out of core processing is not based on memory mapping, but on streaming data in batches (and spilling to disk - if needed), we control which data must be hold in memory, not the OS, meaning that we don't have unexpected IO stalls. - -- DuckDB - - Polars and DuckDB have many similarities. DuckDB is focused on providing an in-process OLAP Sqlite alternative, - Polars is focused on providing a scalable `DataFrame` interface to many languages. Those different front-ends lead to - different optimization strategies and different algorithm prioritization. The interoperability between both is zero-copy. - See more: https://duckdb.org/docs/guides/python/polars - -- Spark - - Spark is designed for distributed workloads and uses the JVM. The setup for spark is complicated and the startup-time - is slow. On a single machine Polars has much better performance characteristics. If you need to process TB's of data - Spark is a better choice. - -- CuDF - - GPU's and CuDF are fast! - However, GPU's are not readily available and expensive in production. The amount of memory available on a GPU - is often a fraction of the available RAM. - This (and out-of-core) processing means that Polars can handle much larger data-sets. - Next to that Polars can be close in [performance to CuDF](https://zakopilo.hatenablog.jp/entry/2023/02/04/220552). - CuDF doesn't optimize your query, so is not uncommon that on ETL jobs Polars will be faster because it can elide - unneeded work and materializations. - -- Any - - Polars is written in Rust. This gives it strong safety, performance and concurrency guarantees. - Polars is written in a modular manner. Parts of Polars can be used in other query programs and can be added as a library. diff --git a/docs/user-guide/misc/comparison.md b/docs/user-guide/misc/comparison.md new file mode 100644 index 000000000000..3ae31fe0077d --- /dev/null +++ b/docs/user-guide/misc/comparison.md @@ -0,0 +1,35 @@ +# Comparison with other tools + +These are several libraries and tools that share similar functionalities with Polars. This often leads to questions from data experts about what the differences are. Below is a short comparison between some of the more popular data processing tools and Polars, to help data experts make a deliberate decision on which tool to use. + +You can find performance benchmarks (h2oai benchmark) of these tools here: [Polars blog post](https://pola.rs/posts/benchmarks/) or a more recent benchmark [done by DuckDB](https://duckdblabs.github.io/db-benchmark/) + +### Pandas + +Pandas stands as a widely-adopted and comprehensive tool in Python data analysis, renowned for its rich feature set and strong community support. However, due to its single threaded nature, it can struggle with performance and memory usage on medium and large datasets. + +In contrast, Polars is optimised for high-performance multithreaded computing on single nodes, providing significant improvements in speed and memory efficiency, particularly for medium to large data operations. Its more composable and stricter API results in greater expressiveness and fewer schema-related bugs. + +### Dask + +Dask extends Pandas' capabilities to large, distributed datasets. Dask mimics Pandas' API, offering a familiar environment for Pandas users, but with the added benefit of parallel and distributed computing. + +While Dask excels at scaling Pandas workflows across clusters, it only supports a subset of the Pandas API and therefore cannot be used for all use cases. Polars offers a more versatile API that delivers strong performance within the constraints of a single node. + +The choice between Dask and Polars often comes down to familiarity with the Pandas API and the need for distributed processing for extremely large datasets versus the need for efficiency and speed in a vertically scaled environment for a wide range of use cases. + +### Modin + +Similar to Dask. In 2023, Snowflake acquired Ponder, the organisation that maintains Modin. + +### Spark + +Spark (specifically PySpark) represents a different approach to large-scale data processing. While Polars has an optimised performance for single-node environments, Spark is designed for distributed data processing across clusters, making it suitable for extremely large datasets. + +However, Spark's distributed nature can introduce complexity and overhead, especially for small datasets and tasks that can run on a single machine. Another consideration is collaboration between data scientists and engineers. As they typically work with different tools (Pandas and Pyspark), refactoring is often required by engineers to deploy data scientists' data processing pipelines. Polars offers a single syntax that, due to vertical scaling, works in local environments and on a single machine in the cloud. + +The choice between Polars and Spark often depends on the scale of data and the specific requirements of the processing task. If you need to process TBs of data, Spark is a better choice. + +### DuckDB + +Polars and DuckDB have many similarities. However, DuckDB is focused on providing an in-process SQL OLAP database management system, while Polars is focused on providing a scalable `DataFrame` interface to many languages. The different front-ends lead to different optimisation strategies and different algorithm prioritisation. The interoperability between both is zero-copy. DuckDB offers a guide on [how to integrate with Polars](https://duckdb.org/docs/guides/python/polars.html). diff --git a/docs/user-guide/misc/visualization.md b/docs/user-guide/misc/visualization.md new file mode 100644 index 000000000000..88dcd83a18a6 --- /dev/null +++ b/docs/user-guide/misc/visualization.md @@ -0,0 +1,60 @@ +# Visualization + +Data in a Polars `DataFrame` can be visualized using common visualization libraries. + +We illustrate plotting capabilities using the Iris dataset. We scan a CSV and then do a group-by on the `species` column and get the mean of the `petal_length`. + +{{code_block('user-guide/misc/visualization','dataframe',[])}} + +```python exec="on" result="text" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:dataframe" +``` + +## Built-in plotting with hvPlot + +Polars has a `plot` method to create interactive plots using [hvPlot](https://hvplot.holoviz.org/). + +{{code_block('user-guide/misc/visualization','hvplot_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:hvplot_make_plot" +``` + +## Matplotlib + +To create a bar chart we can pass columns of a `DataFrame` directly to Matplotlib as a `Series` for each column. Matplotlib does not have explicit support for Polars objects but Matplotlib can accept a Polars `Series` because it can convert each Series to a numpy array, which is zero-copy for numeric +data without null values. + +{{code_block('user-guide/misc/visualization','matplotlib_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:matplotlib_make_plot" +``` + +## Seaborn, Plotly & Altair + +[Seaborn](https://seaborn.pydata.org/), [Plotly](https://plotly.com/) & [Altair](https://altair-viz.github.io/) can accept a Polars `DataFrame` by leveraging the [dataframe interchange protocol](https://data-apis.org/dataframe-api/), which offers zero-copy conversion where possible. + +### Seaborn + +{{code_block('user-guide/misc/visualization','seaborn_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:seaborn_make_plot" +``` + +### Plotly + +{{code_block('user-guide/misc/visualization','plotly_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:plotly_make_plot" +``` + +### Altair + +{{code_block('user-guide/misc/visualization','altair_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:altair_make_plot" +``` diff --git a/docs/user-guide/transformations/joins.md b/docs/user-guide/transformations/joins.md index 07dee43127bf..d568cd63069c 100644 --- a/docs/user-guide/transformations/joins.md +++ b/docs/user-guide/transformations/joins.md @@ -65,7 +65,7 @@ The `outer` join produces a `DataFrame` that contains all the rows from both `Da ### Cross join -A `cross` join is a cartesian product of the two `DataFrames`. This means that every row in the left `DataFrame` is joined with every row in the right `DataFrame`. The `cross` join is useful for creating a `DataFrame` with all possible combinations of the columns in two `DataFrames`. Let's take for example the following two `DataFrames`. +A `cross` join is a Cartesian product of the two `DataFrames`. This means that every row in the left `DataFrame` is joined with every row in the right `DataFrame`. The `cross` join is useful for creating a `DataFrame` with all possible combinations of the columns in two `DataFrames`. Let's take for example the following two `DataFrames`. {{code_block('user-guide/transformations/joins','df3',['DataFrame'])}} diff --git a/examples/datasets/null_nutriscore.csv b/examples/datasets/null_nutriscore.csv new file mode 100644 index 000000000000..0f7922502bc2 --- /dev/null +++ b/examples/datasets/null_nutriscore.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g,nutri_score,proteins_g +seafood,117,9,0,,10 +seafood,201,6,1,,10 +fruit,59,1,14,,10 +meat,97,6,0,,10 +meat,124,12,1,,10 +meat,113,11,1,,10 +vegetables,30,1,1,,10 +seafood,191,6,1,,10 +vegetables,35,0.4,0,,10 +vegetables,21,0,2,,10 +seafood,121,1.5,0,,10 +seafood,125,5,1,,10 +vegetables,21,0,3,,10 +seafood,142,5,0,,10 +meat,118,7,1,,10 +fruit,61,0,12,,10 +fruit,33,1,4,,10 +vegetables,31,0,6,,10 +meat,109,7,2,,10 +vegetables,22,0,1,,10 +fruit,31,0,2,,10 +vegetables,22,0,2,,10 +seafood,155,5,0,,10 +fruit,133,0,27,,10 +seafood,205,9,0,,10 +fruit,72,4.5,7,,10 +fruit,60,1,7,,10 diff --git a/mkdocs.yml b/mkdocs.yml index c26fdd20902e..6673d17741ce 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,14 +1,14 @@ # https://www.mkdocs.org/user-guide/configuration/ # Project information -site_name: Polars User Guide +site_name: Polars user guide site_url: https://docs.pola.rs/ repo_url: https://github.com/pola-rs/polars repo_name: pola-rs/polars # Documentation layout nav: - - User Guide: + - User guide: - index.md - user-guide/getting-started.md - user-guide/installation.md @@ -30,7 +30,7 @@ nav: - user-guide/expressions/casting.md - user-guide/expressions/strings.md - user-guide/expressions/aggregation.md - - user-guide/expressions/null.md + - user-guide/expressions/missing-data.md - user-guide/expressions/window.md - user-guide/expressions/folds.md - user-guide/expressions/lists.md @@ -80,7 +80,8 @@ nav: - user-guide/ecosystem.md - Misc: - user-guide/misc/multiprocessing.md - - user-guide/misc/alternatives.md + - user-guide/misc/visualization.md + - user-guide/misc/comparison.md - API reference: api/index.md @@ -181,4 +182,4 @@ plugins: 'user-guide/basics/index.md': 'user-guide/getting-started.md' 'user-guide/basics/reading-writing.md': 'user-guide/getting-started.md' 'user-guide/basics/expressions.md': 'user-guide/getting-started.md' - 'user-guide/basics/joins.md': 'user-guide/getting-started.md' \ No newline at end of file + 'user-guide/basics/joins.md': 'user-guide/getting-started.md' diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 1b9f07669808..e476573e14ee 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "0.20.6" +version = "0.20.11" edition = "2021" [lib] @@ -22,6 +22,7 @@ either = { workspace = true } itoa = { workspace = true } libc = "0.2" ndarray = { workspace = true } +num-traits = { workspace = true } numpy = { version = "0.20", default-features = false } once_cell = { workspace = true } pyo3 = { workspace = true, features = ["abi3-py38", "extension-module", "multiple-pymethods"] } @@ -58,6 +59,7 @@ features = [ "lazy", "list_eval", "list_to_struct", + "array_to_struct", "log", "mode", "moment", diff --git a/py-polars/Makefile b/py-polars/Makefile index 9365c9481b0b..420c7f75abf9 100644 --- a/py-polars/Makefile +++ b/py-polars/Makefile @@ -54,10 +54,15 @@ build-opt-native: .venv ## Same as build-opt, except with native CPU optimizati build-release-native: .venv ## Same as build-release, except with native CPU optimizations turned on @$(MAKE) -s -C .. $@ +.PHONY: lint +lint: .venv ## Run lint checks + $(VENV_BIN)/ruff check + -$(VENV_BIN)/mypy + .PHONY: fmt -fmt: .venv ## Run autoformatting and linting - $(VENV_BIN)/ruff check . - $(VENV_BIN)/ruff format . +fmt: .venv ## Run autoformatting (and lint) + $(VENV_BIN)/ruff check + $(VENV_BIN)/ruff format $(VENV_BIN)/typos cargo fmt --all -dprint fmt @@ -65,14 +70,14 @@ fmt: .venv ## Run autoformatting and linting .PHONY: clippy clippy: ## Run clippy - cargo clippy --locked -- -D warnings + cargo clippy --locked -- -D warnings -D clippy::dbg_macro .PHONY: pre-commit -pre-commit: fmt clippy ## Run all code quality checks +pre-commit: fmt clippy ## Run all code formatting and lint/quality checks .PHONY: test test: .venv build ## Run fast unittests - $(VENV_BIN)/pytest -n auto --dist loadgroup + $(VENV_BIN)/pytest -n auto --dist loadgroup $(PYTEST_ARGS) .PHONY: doctest doctest: .venv build ## Run doctests @@ -90,7 +95,6 @@ coverage: .venv build ## Run tests and report coverage .PHONY: clean clean: ## Clean up caches and build artifacts - @rm -rf target/ @rm -rf docs/build/ @rm -rf docs/source/reference/api/ @rm -rf .hypothesis/ @@ -100,8 +104,7 @@ clean: ## Clean up caches and build artifacts @rm -f .coverage @rm -f coverage.xml @rm -f polars/polars.abi3.so - @find . -type f -name '*.py[co]' -delete -or -type d -name __pycache__ -delete - @cargo clean + @find . -type f -name '*.py[co]' -delete -or -type d -name __pycache__ -exec rm -r {} + .PHONY: help help: ## Display this help screen diff --git a/py-polars/debug/launch.py b/py-polars/debug/launch.py new file mode 100644 index 000000000000..95352e4eafa3 --- /dev/null +++ b/py-polars/debug/launch.py @@ -0,0 +1,81 @@ +import os +import re +import sys +import time +from pathlib import Path + +""" +The following parameter determines the sleep time of the Python process after a signal +is sent that attaches the Rust LLDB debugger. If the Rust LLDB debugger attaches to the +current session too late, it might miss any set breakpoints. If this happens +consistently, it is recommended to increase this value. +""" +LLDB_DEBUG_WAIT_TIME_SECONDS = 1 + + +def launch_debugging() -> None: + """ + Debug Rust files via Python. + + Determine the pID for the current debugging session, attach the Rust LLDB launcher, + and execute the originally-requested script. + """ + if len(sys.argv) == 1: + msg = ( + "launch.py is not meant to be executed directly; please use the `Python: " + "Debug Rust` debugging configuration to run a python script that uses the " + "polars library." + ) + raise RuntimeError(msg) + + # Get the current process ID. + pID = os.getpid() + + # Print to the debug console to allow VSCode to pick up on the signal and start the + # Rust LLDB configuration automatically. + launch_file = Path(__file__).parents[2] / ".vscode/launch.json" + if not launch_file.exists(): + msg = f"Cannot locate {launch_file}" + raise RuntimeError(msg) + with launch_file.open("r") as f: + launch_info = f.read() + + # Overwrite the pid found in launch.json with the pid for the current process. + # Match the initial "Rust LLDB" definition with the pid defined immediately after. + pattern = re.compile('("Rust LLDB",\\s*"pid":\\s*")\\d+(")') + found = pattern.search(launch_info) + if not found: + msg = ( + "Cannot locate pid definition in launch.json for Rust LLDB configuration. " + "Please follow the instructions in CONTRIBUTING.md for creating the " + "launch configuration." + ) + raise RuntimeError(msg) + + launch_info_with_new_pid = pattern.sub(rf"\g<1>{pID}\g<2>", launch_info) + with launch_file.open("w") as f: + f.write(launch_info_with_new_pid) + + # Print pID to the debug console. This auto-triggers the Rust LLDB configurations. + print(f"pID = {pID}") + + # Give the LLDB time to connect. Depending on how long it takes for your LLDB + # debugging session to initiatialize, you may have to adjust this setting. + time.sleep(LLDB_DEBUG_WAIT_TIME_SECONDS) + + # Update sys.argv so that when exec() is called, the first argument is the script + # name itself, and the remaining are the input arguments. + sys.argv.pop(0) + with Path(sys.argv[0]).open() as fh: + script_contents = fh.read() + + # Run the originally requested file by reading in the script, compiling, and + # executing the code. + file_to_execute = Path(sys.argv[0]) + exec( + compile(script_contents, file_to_execute, mode="exec"), {"__name__": "__main__"} + ) + + +if __name__ == "__main__": + launch_debugging() diff --git a/py-polars/docs/requirements-docs.txt b/py-polars/docs/requirements-docs.txt index dfc9cb34f0b0..75e19f79fcc3 100644 --- a/py-polars/docs/requirements-docs.txt +++ b/py-polars/docs/requirements-docs.txt @@ -1,10 +1,8 @@ ---prefer-binary - numpy pandas pyarrow -hypothesis==6.92.1 +hypothesis==6.97.4 sphinx==7.2.4 diff --git a/py-polars/docs/source/reference/expressions/array.rst b/py-polars/docs/source/reference/expressions/array.rst index 067393e242c7..dd3d7be45d98 100644 --- a/py-polars/docs/source/reference/expressions/array.rst +++ b/py-polars/docs/source/reference/expressions/array.rst @@ -11,9 +11,12 @@ The following methods are available under the `expr.arr` attribute. Expr.arr.max Expr.arr.min + Expr.arr.median Expr.arr.sum + Expr.arr.std Expr.arr.to_list Expr.arr.unique + Expr.arr.var Expr.arr.all Expr.arr.any Expr.arr.sort @@ -27,3 +30,5 @@ The following methods are available under the `expr.arr` attribute. Expr.arr.explode Expr.arr.contains Expr.arr.count_matches + Expr.arr.to_struct + Expr.arr.shift diff --git a/py-polars/docs/source/reference/expressions/functions.rst b/py-polars/docs/source/reference/expressions/functions.rst index 3224ea475d07..3fad1cb7f989 100644 --- a/py-polars/docs/source/reference/expressions/functions.rst +++ b/py-polars/docs/source/reference/expressions/functions.rst @@ -65,6 +65,7 @@ These functions are available from the polars module root and can be used as exp max max_horizontal mean + mean_horizontal median min min_horizontal diff --git a/py-polars/docs/source/reference/expressions/list.rst b/py-polars/docs/source/reference/expressions/list.rst index 1d794108bb18..d168e3976f02 100644 --- a/py-polars/docs/source/reference/expressions/list.rst +++ b/py-polars/docs/source/reference/expressions/list.rst @@ -31,6 +31,7 @@ The following methods are available under the `expr.list` attribute. Expr.list.lengths Expr.list.max Expr.list.mean + Expr.list.median Expr.list.min Expr.list.reverse Expr.list.sample @@ -41,9 +42,13 @@ The following methods are available under the `expr.list` attribute. Expr.list.shift Expr.list.slice Expr.list.sort + Expr.list.std Expr.list.sum Expr.list.tail Expr.list.take Expr.list.to_array Expr.list.to_struct Expr.list.unique + Expr.list.n_unique + Expr.list.var + Expr.list.gather_every diff --git a/py-polars/docs/source/reference/expressions/meta.rst b/py-polars/docs/source/reference/expressions/meta.rst index c2bfed4728cf..22868fad271d 100644 --- a/py-polars/docs/source/reference/expressions/meta.rst +++ b/py-polars/docs/source/reference/expressions/meta.rst @@ -17,5 +17,6 @@ The following methods are available under the `expr.meta` attribute. Expr.meta.pop Expr.meta.tree_format Expr.meta.root_names + Expr.meta.serialize Expr.meta.undo_aliases Expr.meta.write_json diff --git a/py-polars/docs/source/reference/expressions/miscellaneous.rst b/py-polars/docs/source/reference/expressions/miscellaneous.rst index a1997ccb9031..c0ea4d2caf1b 100644 --- a/py-polars/docs/source/reference/expressions/miscellaneous.rst +++ b/py-polars/docs/source/reference/expressions/miscellaneous.rst @@ -6,5 +6,6 @@ Miscellaneous .. autosummary:: :toctree: api/ - Expr.from_json - Expr.set_sorted + Expr.deserialize + Expr.from_json + Expr.set_sorted diff --git a/py-polars/docs/source/reference/expressions/name.rst b/py-polars/docs/source/reference/expressions/name.rst index 91d20a9b41b3..c687651d6278 100644 --- a/py-polars/docs/source/reference/expressions/name.rst +++ b/py-polars/docs/source/reference/expressions/name.rst @@ -15,3 +15,6 @@ The following methods are available under the `expr.name` attribute. Expr.name.suffix Expr.name.to_lowercase Expr.name.to_uppercase + Expr.name.map_fields + Expr.name.prefix_fields + Expr.name.suffix_fields diff --git a/py-polars/docs/source/reference/index.rst b/py-polars/docs/source/reference/index.rst index 13d31f1c33e3..d99d14bb5565 100644 --- a/py-polars/docs/source/reference/index.rst +++ b/py-polars/docs/source/reference/index.rst @@ -20,5 +20,5 @@ methods. All classes and functions exposed in ``polars.*`` namespace are public. config exceptions testing - utils sql + metadata diff --git a/py-polars/docs/source/reference/utils.rst b/py-polars/docs/source/reference/metadata.rst similarity index 73% rename from py-polars/docs/source/reference/utils.rst rename to py-polars/docs/source/reference/metadata.rst index 0ee4d3a55054..4d9c0dbf9c60 100644 --- a/py-polars/docs/source/reference/utils.rst +++ b/py-polars/docs/source/reference/metadata.rst @@ -1,6 +1,6 @@ -===== -Utils -===== +======== +Metadata +======== .. currentmodule:: polars .. autosummary:: @@ -9,4 +9,5 @@ Utils build_info get_index_type show_versions + thread_pool_size threadpool_size diff --git a/py-polars/docs/source/reference/series/array.rst b/py-polars/docs/source/reference/series/array.rst index 0ea0269e52c5..13f2da759833 100644 --- a/py-polars/docs/source/reference/series/array.rst +++ b/py-polars/docs/source/reference/series/array.rst @@ -11,9 +11,12 @@ The following methods are available under the `Series.arr` attribute. Series.arr.max Series.arr.min + Series.arr.median Series.arr.sum + Series.arr.std Series.arr.to_list Series.arr.unique + Series.arr.var Series.arr.all Series.arr.any Series.arr.sort @@ -26,4 +29,6 @@ The following methods are available under the `Series.arr` attribute. Series.arr.join Series.arr.explode Series.arr.contains - Series.arr.count_matches \ No newline at end of file + Series.arr.count_matches + Series.arr.to_struct + Series.arr.shift \ No newline at end of file diff --git a/py-polars/docs/source/reference/series/list.rst b/py-polars/docs/source/reference/series/list.rst index cdce24994f76..2398fe0ea24d 100644 --- a/py-polars/docs/source/reference/series/list.rst +++ b/py-polars/docs/source/reference/series/list.rst @@ -31,6 +31,7 @@ The following methods are available under the `Series.list` attribute. Series.list.lengths Series.list.max Series.list.mean + Series.list.median Series.list.min Series.list.reverse Series.list.sample @@ -41,9 +42,13 @@ The following methods are available under the `Series.list` attribute. Series.list.shift Series.list.slice Series.list.sort + Series.list.std Series.list.sum Series.list.tail Series.list.take Series.list.to_array Series.list.to_struct Series.list.unique + Series.list.n_unique + Series.list.var + Series.list.gather_every \ No newline at end of file diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index e0de6cb683c7..d7f093484221 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -147,6 +147,7 @@ max, max_horizontal, mean, + mean_horizontal, median, min, min_horizontal, @@ -197,6 +198,13 @@ scan_pyarrow_dataset, ) from polars.lazyframe import InProcessQuery, LazyFrame +from polars.meta import ( + build_info, + get_index_type, + show_versions, + thread_pool_size, + threadpool_size, +) from polars.series import Series from polars.sql import SQLContext from polars.string_cache import ( @@ -206,11 +214,10 @@ using_string_cache, ) from polars.type_aliases import PolarsDataType -from polars.utils import build_info, get_index_type, show_versions, threadpool_size +from polars.utils._polars_version import get_polars_version as _get_polars_version # TODO: remove need for importing wrap utils at top level from polars.utils._wrap import wrap_df, wrap_s # noqa: F401 -from polars.utils.polars_version import get_polars_version as _get_polars_version __version__: str = _get_polars_version() del _get_polars_version @@ -342,6 +349,7 @@ "cum_sum_horizontal", "cumsum_horizontal", "max_horizontal", + "mean_horizontal", "min_horizontal", "sum_horizontal", # polars.functions.lazy @@ -415,6 +423,7 @@ "build_info", "get_index_type", "show_versions", + "thread_pool_size", "threadpool_size", # selectors "selectors", diff --git a/py-polars/polars/api.py b/py-polars/polars/api.py index 4866fcbc349a..f1020b50bc09 100644 --- a/py-polars/polars/api.py +++ b/py-polars/polars/api.py @@ -43,14 +43,13 @@ def __get__(self, instance: NS | None, cls: type[NS]) -> NS | type[NS]: return self._ns ns_instance = self._ns(instance) # type: ignore[call-arg] - setattr(instance, self._accessor, ns_instance) return ns_instance def _create_namespace( name: str, cls: type[Expr | DataFrame | LazyFrame | Series] ) -> Callable[[type[NS]], type[NS]]: - """Register custom namespace against the underlying polars class.""" + """Register custom namespace against the underlying Polars class.""" def namespace(ns_class: type[NS]) -> type[NS]: if name in _reserved_namespaces: @@ -72,7 +71,7 @@ def namespace(ns_class: type[NS]) -> type[NS]: def register_expr_namespace(name: str) -> Callable[[type[NS]], type[NS]]: """ - Decorator for registering custom functionality with a polars Expr. + Decorator for registering custom functionality with a Polars Expr. Parameters ---------- @@ -125,7 +124,7 @@ def register_expr_namespace(name: str) -> Callable[[type[NS]], type[NS]]: def register_dataframe_namespace(name: str) -> Callable[[type[NS]], type[NS]]: """ - Decorator for registering custom functionality with a polars DataFrame. + Decorator for registering custom functionality with a Polars DataFrame. Parameters ---------- @@ -223,7 +222,7 @@ def register_dataframe_namespace(name: str) -> Callable[[type[NS]], type[NS]]: def register_lazyframe_namespace(name: str) -> Callable[[type[NS]], type[NS]]: """ - Decorator for registering custom functionality with a polars LazyFrame. + Decorator for registering custom functionality with a Polars LazyFrame. Parameters ---------- diff --git a/py-polars/polars/config.py b/py-polars/polars/config.py index 889e1cf011b8..a481597f3ae0 100644 --- a/py-polars/polars/config.py +++ b/py-polars/polars/config.py @@ -305,7 +305,7 @@ def save_to_file(cls, file: Path | str) -> None: Examples -------- - >>> json_file = pl.Config().save("~/polars/config.json") # doctest: +SKIP + >>> pl.Config().save_to_file("~/polars/config.json") # doctest: +SKIP """ file = Path(normalize_filepath(file)).resolve() file.write_text(cls.save()) @@ -608,7 +608,7 @@ def set_fmt_float(cls, fmt: FloatFmt | None = "mixed") -> type[Config]: How to format floating point numbers: - "mixed": Limit the number of decimal places and use scientific - notation for large/small values. + notation for large/small values. - "full": Print the full precision of the floating point number. Examples diff --git a/py-polars/polars/convert.py b/py-polars/polars/convert.py index 1dec0fcd1012..ed5ea49f21d7 100644 --- a/py-polars/polars/convert.py +++ b/py-polars/polars/convert.py @@ -107,8 +107,8 @@ def from_dicts( schema_overrides : dict, default None Support override of inferred types for one or more columns. infer_schema_length - How many dictionaries/rows to scan to determine the data types - if set to `None` then ALL dicts are scanned; this will be slow. + The maximum number of rows to scan for schema inference. + If set to `None`, the full data may be scanned *(this is slow)*. Returns ------- @@ -211,8 +211,8 @@ def from_records( the orientation is inferred by matching the columns and data dimensions. If this does not yield conclusive results, column orientation is used. infer_schema_length - How many dictionaries/rows to scan to determine the data types - if set to `None` all rows are scanned. This will be slow. + The maximum number of rows to scan for schema inference. + If set to `None`, the full data may be scanned *(this is slow)*. Returns ------- @@ -588,9 +588,12 @@ def from_arrow( 3 ] """ # noqa: W505 - if isinstance(data, pa.Table): + if isinstance(data, (pa.Table, pa.RecordBatch)): return pl.DataFrame._from_arrow( - data=data, rechunk=rechunk, schema=schema, schema_overrides=schema_overrides + data=data, + rechunk=rechunk, + schema=schema, + schema_overrides=schema_overrides, ) elif isinstance(data, (pa.Array, pa.ChunkedArray)): name = getattr(data, "_name", "") or "" @@ -606,8 +609,6 @@ def from_arrow( schema_overrides=schema_overrides, ) - if isinstance(data, pa.RecordBatch): - data = [data] if isinstance(data, Iterable): return pl.DataFrame._from_arrow( data=pa.Table.from_batches( @@ -649,7 +650,7 @@ def from_pandas( def from_pandas( - data: pd.DataFrame | pd.Series[Any] | pd.Index[Any], + data: pd.DataFrame | pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex, *, schema_overrides: SchemaDict | None = None, rechunk: bool = True, @@ -712,7 +713,7 @@ def from_pandas( 3 ] """ - if isinstance(data, (pd.Series, pd.DatetimeIndex)): + if isinstance(data, (pd.Series, pd.Index, pd.DatetimeIndex)): return pl.Series._from_pandas("", data, nan_to_null=nan_to_null) elif isinstance(data, pd.DataFrame): return pl.DataFrame._from_pandas( diff --git a/py-polars/polars/dataframe/_html.py b/py-polars/polars/dataframe/_html.py index 99f52ff94dc3..06e729241f44 100644 --- a/py-polars/polars/dataframe/_html.py +++ b/py-polars/polars/dataframe/_html.py @@ -58,15 +58,16 @@ def __init__( self.elements: list[str] = [] self.max_cols = max_cols self.max_rows = max_rows - self.series = from_series + self.from_series = from_series self.row_idx: Iterable[int] self.col_idx: Iterable[int] if max_rows < df.height: + half, rest = divmod(max_rows, 2) self.row_idx = [ - *list(range(max_rows // 2)), + *list(range(half + rest)), -1, - *list(range(df.height - max_rows // 2, df.height)), + *list(range(df.height - half, df.height)), ] else: self.row_idx = range(df.height) @@ -132,7 +133,7 @@ def render(self) -> list[str]: ): # format frame/series shape with '_' thousand-separators s = self.df.shape - shape = f"({s[0]:_},)" if self.series else f"({s[0]:_}, {s[1]:_})" + shape = f"({s[0]:_},)" if self.from_series else f"({s[0]:_}, {s[1]:_})" self.elements.append(f"shape: {shape}") diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 052ec096fec3..235b1b47f214 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -13,7 +13,6 @@ IO, TYPE_CHECKING, Any, - BinaryIO, Callable, ClassVar, Collection, @@ -50,7 +49,6 @@ _check_for_numpy, _check_for_pandas, _check_for_pyarrow, - dataframe_api_compat, hvplot, import_optional, ) @@ -79,7 +77,6 @@ from polars.slice import PolarsSlice from polars.type_aliases import DbWriteMode from polars.utils._construction import ( - _post_apply_columns, arrow_to_pydf, dict_to_pydf, frame_to_pydf, @@ -227,9 +224,11 @@ class DataFrame: Whether to interpret two-dimensional data as columns or as rows. If None, the orientation is inferred by matching the columns and data dimensions. If this does not yield conclusive results, column orientation is used. - infer_schema_length : int, default None - Maximum number of rows to read for schema inference; only applies if the input - data is a sequence or generator of rows; other input is read as-is. + infer_schema_length : int or None + The maximum number of rows to scan for schema inference. + If set to `None`, the full data may be scanned *(this is slow)*. + This parameter only applies if the input data is a sequence or generator of + rows; other input is read as-is. nan_to_null : bool, default False If the data comes from one or more numpy arrays, can optionally convert input data np.nan values to null instead. This is a no-op for all other input data. @@ -348,6 +347,8 @@ class DataFrame: False """ + __slots__ = ("_df",) + _df: PyDataFrame _accessors: ClassVar[set[str]] = {"plot"} def __init__( @@ -433,24 +434,6 @@ def _from_pydf(cls, py_df: PyDataFrame) -> Self: df._df = py_df return df - @classmethod - def _from_dicts( - cls, - data: Sequence[dict[str, Any]], - schema: SchemaDefinition | None = None, - *, - schema_overrides: SchemaDict | None = None, - infer_schema_length: int | None = N_INFER_DEFAULT, - ) -> Self: - pydf = PyDataFrame.read_dicts( - data, infer_schema_length, schema, schema_overrides - ) - if schema or schema_overrides: - pydf = _post_apply_columns( - pydf, list(schema or pydf.columns()), schema_overrides=schema_overrides - ) - return cls._from_pydf(pydf) - @classmethod def _from_dict( cls, @@ -498,29 +481,9 @@ def _from_records( """ Construct a DataFrame from a sequence of sequences. - Parameters - ---------- - data : Sequence of sequences - Two-dimensional data represented as a sequence of sequences. - schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict - The DataFrame schema may be declared in several ways: - - * As a dict of {name:type} pairs; if type is None, it will be auto-inferred. - * As a list of column names; in this case types are automatically inferred. - * As a list of (name,type) pairs; this is equivalent to the dictionary form. - - If you supply a list of column names that does not match the names in the - underlying data, the names given here will overwrite them. The number - of names given in the schema should match the underlying data dimensions. - schema_overrides : dict, default None - Support type specification or override of one or more columns; note that - any dtypes inferred from the columns param will be overridden. - orient : {'col', 'row'}, default None - Whether to interpret two-dimensional data as columns or as rows. If None, - the orientation is inferred by matching the columns and data dimensions. If - this does not yield conclusive results, column orientation is used. - infer_schema_length - How many rows to scan to determine the column type. + See Also + -------- + polars.io.from_records """ return cls._from_pydf( sequence_to_pydf( @@ -575,7 +538,7 @@ def _from_numpy( @classmethod def _from_arrow( cls, - data: pa.Table, + data: pa.Table | pa.RecordBatch, schema: SchemaDefinition | None = None, *, schema_overrides: SchemaDict | None = None, @@ -589,8 +552,8 @@ def _from_arrow( Parameters ---------- - data : arrow table, array, or sequence of sequences - Data representing an Arrow Table or Array. + data : arrow Table, RecordBatch, or sequence of sequences + Data representing an Arrow Table or RecordBatch. schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict The DataFrame schema may be declared in several ways: @@ -886,7 +849,7 @@ def _read_parquet( @classmethod def _read_avro( cls, - source: str | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, columns: Sequence[int] | Sequence[str] | None = None, n_rows: int | None = None, @@ -1050,9 +1013,9 @@ def _read_json( cls, source: str | Path | IOBase | bytes, *, - infer_schema_length: int | None = N_INFER_DEFAULT, schema: SchemaDefinition | None = None, schema_overrides: SchemaDefinition | None = None, + infer_schema_length: int | None = N_INFER_DEFAULT, ) -> Self: """ Read into a DataFrame from a JSON file. @@ -1376,19 +1339,6 @@ def __dataframe__( return PolarsDataFrame(self, allow_copy=allow_copy) - def __dataframe_consortium_standard__( - self, *, api_version: str | None = None - ) -> Any: - """ - Provide entry point to the Consortium DataFrame Standard API. - - This is developed and maintained outside of polars. - Please report any issues to https://github.com/data-apis/dataframe-api-compat. - """ - return dataframe_api_compat.polars_standard.convert_to_standard_compliant_dataframe( - self.lazy(), api_version=api_version - ) - def _comp(self, other: Any, op: ComparisonOperator) -> DataFrame: """Compare a DataFrame with another object.""" if isinstance(other, DataFrame): @@ -1858,7 +1808,7 @@ def __deepcopy__(self, memo: None = None) -> Self: def _ipython_key_completions_(self) -> list[str]: return self.columns - def _repr_html_(self, **kwargs: Any) -> str: + def _repr_html_(self, *, _from_series: bool = False) -> str: """ Format output data in HTML for display in Jupyter Notebooks. @@ -1870,18 +1820,18 @@ def _repr_html_(self, **kwargs: Any) -> str: """ max_cols = int(os.environ.get("POLARS_FMT_MAX_COLS", default=75)) if max_cols < 0: - max_cols = self.shape[1] - max_rows = int(os.environ.get("POLARS_FMT_MAX_ROWS", default=25)) + max_cols = self.width + + max_rows = int(os.environ.get("POLARS_FMT_MAX_ROWS", default=10)) if max_rows < 0: - max_rows = self.shape[0] + max_rows = self.height - from_series = kwargs.get("from_series", False) return "".join( NotebookFormatter( self, max_cols=max_cols, max_rows=max_rows, - from_series=from_series, + from_series=_from_series, ).render() ) @@ -2099,18 +2049,21 @@ def to_numpy( structured: bool = False, # noqa: FBT001 *, order: IndexOrder = "fortran", + allow_copy: bool = True, + writable: bool = False, use_pyarrow: bool = True, ) -> np.ndarray[Any, Any]: """ - Convert DataFrame to a 2D NumPy array. - - This operation clones data. + Convert this DataFrame to a NumPy ndarray. Parameters ---------- structured - Optionally return a structured array, with field names and - dtypes that correspond to the DataFrame schema. + Return a `structured array`_ with a data type that corresponds to the + DataFrame schema. If set to `False` (default), a 2D ndarray is + returned instead. + + .. _structured array: https://numpy.org/doc/stable/user/basics.rec.html order The index order of the returned NumPy array, either C-like or Fortran-like. In general, using the Fortran-like index order is faster. @@ -2119,17 +2072,19 @@ def to_numpy( one-dimensional array. Note that this option only takes effect if `structured` is set to `False` and the DataFrame dtypes allow for a global dtype for all columns. + allow_copy + Allow memory to be copied to perform the conversion. If set to `False`, + causes conversions that are not zero-copy to fail. + writable + Ensure the resulting array is writable. This will force a copy of the data + if the array was created without copy, as the underlying Arrow data is + immutable. use_pyarrow Use `pyarrow.Array.to_numpy `_ function for the conversion to numpy if necessary. - Notes - ----- - If you're attempting to convert String or Decimal to an array, you'll need to - install `pyarrow`. - Examples -------- >>> df = pl.DataFrame( @@ -2162,32 +2117,49 @@ def to_numpy( rec.array([(1, 6.5, 'a'), (2, 7. , 'b'), (3, 8.5, 'c')], dtype=[('foo', 'u1'), ('bar', ' None: + if not allow_copy and not self.is_empty(): + msg = f"copy not allowed: {msg}" + raise RuntimeError(msg) + if structured: - # see: https://numpy.org/doc/stable/user/basics.rec.html - arrays = [] - for c, tp in self.schema.items(): - s = self[c] - a = s.to_numpy(use_pyarrow=use_pyarrow) - arrays.append( - a.astype(str, copy=False) - if tp == String and not s.null_count() - else a - ) + raise_on_copy("cannot create structured array without copying data") - out = np.empty( - len(self), dtype=list(zip(self.columns, (a.dtype for a in arrays))) - ) + arrays = [] + struct_dtype = [] + for s in self.iter_columns(): + arr = s.to_numpy(use_pyarrow=use_pyarrow) + if s.dtype == String and s.null_count() == 0: + arr = arr.astype(str, copy=False) + arrays.append(arr) + struct_dtype.append((s.name, arr.dtype)) + + out = np.empty(self.height, dtype=struct_dtype) for idx, c in enumerate(self.columns): out[c] = arrays[idx] - else: - out = self._df.to_numpy(order) - if out is None: - return np.vstack( - [ - self.to_series(i).to_numpy(use_pyarrow=use_pyarrow) - for i in range(self.width) - ] - ).T + return out + + if order == "fortran": + array = self._df.to_numpy_view() + if array is not None: + if writable and not array.flags.writeable: + raise_on_copy("cannot create writable array without copying data") + array = array.copy() + return array + + raise_on_copy( + "only numeric data without nulls in Fortran-like order can be converted without copy" + ) + + out = self._df.to_numpy(order) + if out is None: + return np.vstack( + [ + self.to_series(i).to_numpy(use_pyarrow=use_pyarrow) + for i in range(self.width) + ] + ).T return out @@ -2479,7 +2451,7 @@ def write_json( Parameters ---------- file - File path or writeable file-like object to which the result will be written. + File path or writable file-like object to which the result will be written. If set to `None` (default), the output is returned as a string instead. pretty Pretty serialize json. @@ -2535,7 +2507,7 @@ def write_ndjson(self, file: IOBase | str | Path | None = None) -> str | None: Parameters ---------- file - File path or writeable file-like object to which the result will be written. + File path or writable file-like object to which the result will be written. If set to `None` (default), the output is returned as a string instead. Examples @@ -2589,7 +2561,7 @@ def write_csv( @overload def write_csv( self, - file: BytesIO | TextIOWrapper | str | Path, + file: str | Path | IO[str] | IO[bytes], *, include_bom: bool = ..., include_header: bool = ..., @@ -2610,7 +2582,7 @@ def write_csv( @deprecate_renamed_parameter("has_header", "include_header", version="0.19.13") def write_csv( self, - file: BytesIO | TextIOWrapper | str | Path | None = None, + file: str | Path | IO[str] | IO[bytes] | None = None, *, include_bom: bool = False, include_header: bool = True, @@ -2631,7 +2603,7 @@ def write_csv( Parameters ---------- file - File path or writeable file-like object to which the result will be written. + File path or writable file-like object to which the result will be written. If set to `None` (default), the output is returned as a string instead. include_bom Whether to include UTF-8 BOM in the CSV output. @@ -2732,7 +2704,7 @@ def write_csv( def write_avro( self, - file: BinaryIO | BytesIO | str | Path, + file: str | Path | IO[bytes], compression: AvroCompression = "uncompressed", name: str = "", ) -> None: @@ -2742,7 +2714,7 @@ def write_avro( Parameters ---------- file - File path or writeable file-like object to which the data will be written. + File path or writable file-like object to which the data will be written. compression : {'uncompressed', 'snappy', 'deflate'} Compression method. Defaults to "uncompressed". name @@ -2774,7 +2746,7 @@ def write_avro( @deprecate_renamed_parameter("has_header", "include_header", version="0.19.13") def write_excel( self, - workbook: Workbook | BytesIO | Path | str | None = None, + workbook: Workbook | IO[bytes] | Path | str | None = None, worksheet: str | None = None, *, position: tuple[int, int] | str = "A1", @@ -2933,7 +2905,7 @@ def write_excel( "A2" indicates the split occurs at the top-left of cell A2, which is the equivalent of (1, 0). * If (row, col, top_row, top_col) are supplied, the panes are split based on - the `row` and `col`, and the scrolling region is inititalized to begin at + the `row` and `col`, and the scrolling region is initialized to begin at the `top_row` and `top_col`. Thus, to freeze only the top row and have the scrolling region begin at row 10, column D (5th col), supply (1, 0, 9, 4). Using cell notation for (row, col), supplying ("A2", 9, 4) is equivalent. @@ -3288,7 +3260,7 @@ def write_ipc( @overload def write_ipc( self, - file: BinaryIO | BytesIO | str | Path, + file: str | Path | IO[bytes], compression: IpcCompression = "uncompressed", *, future: bool = False, @@ -3297,7 +3269,7 @@ def write_ipc( def write_ipc( self, - file: BinaryIO | BytesIO | str | Path | None, + file: str | Path | IO[bytes] | None, compression: IpcCompression = "uncompressed", *, future: bool = False, @@ -3310,7 +3282,7 @@ def write_ipc( Parameters ---------- file - Path or writeable file-like object to which the IPC data will be + Path or writable file-like object to which the IPC data will be written. If set to `None`, the output is returned as a BytesIO object. compression : {'uncompressed', 'lz4', 'zstd'} Compression method. Defaults to "uncompressed". @@ -3364,14 +3336,14 @@ def write_ipc_stream( @overload def write_ipc_stream( self, - file: BinaryIO | BytesIO | str | Path, + file: str | Path | IO[bytes], compression: IpcCompression = "uncompressed", ) -> None: ... def write_ipc_stream( self, - file: BinaryIO | BytesIO | str | Path | None, + file: str | Path | IO[bytes] | None, compression: IpcCompression = "uncompressed", ) -> BytesIO | None: """ @@ -3382,7 +3354,7 @@ def write_ipc_stream( Parameters ---------- file - Path or writeable file-like object to which the IPC record batch data will + Path or writable file-like object to which the IPC record batch data will be written. If set to `None`, the output is returned as a BytesIO object. compression : {'uncompressed', 'lz4', 'zstd'} Compression method. Defaults to "uncompressed". @@ -3431,7 +3403,7 @@ def write_parquet( Parameters ---------- file - File path or writeable file-like object to which the result will be written. + File path or writable file-like object to which the result will be written. compression : {'lz4', 'uncompressed', 'snappy', 'gzip', 'lzo', 'brotli', 'zstd'} Choose "zstd" for good compression performance. Choose "lz4" for fast compression/decompression. @@ -4303,7 +4275,7 @@ def glimpse( *, max_items_per_column: int = ..., max_colname_length: int = ..., - return_as_string: Literal[False], + return_as_string: Literal[False] = ..., ) -> None: ... @@ -4317,6 +4289,16 @@ def glimpse( ) -> str: ... + @overload + def glimpse( + self, + *, + max_items_per_column: int = ..., + max_colname_length: int = ..., + return_as_string: bool, + ) -> str | None: + ... + def glimpse( self, *, @@ -4475,10 +4457,11 @@ def describe( Customize which percentiles are displayed, applying linear interpolation: - >>> df.describe( - ... percentiles=[0.1, 0.3, 0.5, 0.7, 0.9], - ... interpolation="linear", - ... ) + >>> with pl.Config(tbl_rows=12): + ... df.describe( + ... percentiles=[0.1, 0.3, 0.5, 0.7, 0.9], + ... interpolation="linear", + ... ) shape: (11, 7) ┌────────────┬──────────┬──────────┬──────────┬──────┬────────────┬──────────┐ │ statistic ┆ float ┆ int ┆ bool ┆ str ┆ date ┆ time │ @@ -5362,22 +5345,21 @@ def with_row_count(self, name: str = "row_nr", offset: int = 0) -> Self: """ return self.with_row_index(name, offset) + @deprecate_parameter_as_positional("by", version="0.20.7") def group_by( self, - by: IntoExpr | Iterable[IntoExpr], - *more_by: IntoExpr, + *by: IntoExpr | Iterable[IntoExpr], maintain_order: bool = False, + **named_by: IntoExpr, ) -> GroupBy: """ Start a group by operation. Parameters ---------- - by + *by Column(s) to group by. Accepts expression input. Strings are parsed as column names. - *more_by - Additional columns to group by, specified as positional arguments. maintain_order Ensure that the order of the groups is consistent with the input data. This is slower than a default group by. @@ -5387,6 +5369,9 @@ def group_by( .. note:: Within each group, the order of rows is always preserved, regardless of this argument. + **named_by + Additional columns to group by, specified as keyword arguments. + The columns will be renamed to the keyword used. Returns ------- @@ -5498,7 +5483,7 @@ def group_by( │ c ┆ 3 ┆ 1 │ └─────┴─────┴─────┘ """ - return GroupBy(self, by, *more_by, maintain_order=maintain_order) + return GroupBy(self, *by, **named_by, maintain_order=maintain_order) def rolling( self, @@ -5511,7 +5496,7 @@ def rolling( check_sorted: bool = True, ) -> RollingGroupBy: """ - Create rolling groups based on a time, Int32, or Int64 column. + Create rolling groups based on a temporal or integer column. Different from a `group_by_dynamic` the windows are now determined by the individual values and are not of constant intervals. For constant intervals use @@ -5555,11 +5540,6 @@ def rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a rolling operation on an integer column, the windows are defined by: - - - **"1i" # length 1** - - **"10i" # length 10** - Parameters ---------- index_column @@ -5569,8 +5549,8 @@ def rolling( then it must be sorted in ascending order within each group). In case of a rolling operation on indices, dtype needs to be one of - {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if - performance matters use an Int64 column. + {UInt32, UInt64, Int32, Int64}. Note that the first three get temporarily + cast to Int64, so if performance matters use an Int64 column. period length of the window - must be non-negative offset @@ -6289,7 +6269,7 @@ def join( * *outer_coalesce* Same as 'outer', but coalesces the key columns * *cross* - Returns the cartisian product of rows from both tables + Returns the Cartesian product of rows from both tables * *semi* Filter rows that have a match in the right table. * *anti* @@ -6404,7 +6384,7 @@ def join( Notes ----- - For joining on columns with categorical data, see `pl.StringCache()`. + For joining on columns with categorical data, see :class:`polars.StringCache`. """ if not isinstance(other, DataFrame): msg = f"expected `other` join table to be a DataFrame, got {type(other).__name__!r}" @@ -7222,7 +7202,7 @@ def explode( ---------- columns Column names, expressions, or a selector defining them. The underlying - columns being exploded must be of List or String datatype. + columns being exploded must be of the `List` or `Array` data type. *more_columns Additional names of columns to explode, specified as positional arguments. @@ -7269,6 +7249,15 @@ def explode( """ return self.lazy().explode(columns, *more_columns).collect(_eager=True) + @deprecate_nonkeyword_arguments( + allowed_args=["self"], + message=( + "The order of the parameters of `pivot` will change in the next breaking release." + " The order will become `index, columns, values` with `values` as an optional parameter." + " Use keyword arguments to silence this warning." + ), + version="0.20.8", + ) def pivot( self, values: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None, @@ -7290,7 +7279,8 @@ def pivot( ---------- values Column values to aggregate. Can be multiple columns if the *columns* - arguments contains multiple columns as well. + arguments contains multiple columns as well. If None, all remaining columns + will be used. index One or multiple keys to group by. columns @@ -7323,7 +7313,7 @@ def pivot( ... "baz": [1, 2, 3, 4, 5, 6], ... } ... ) - >>> df.pivot(values="baz", index="foo", columns="bar", aggregate_function="sum") + >>> df.pivot(index="foo", columns="bar", values="baz", aggregate_function="sum") shape: (2, 3) ┌─────┬─────┬─────┐ │ foo ┆ y ┆ x │ @@ -7338,25 +7328,25 @@ def pivot( >>> import polars.selectors as cs >>> df.pivot( - ... values=cs.numeric(), ... index=cs.string(), ... columns=cs.string(), + ... values=cs.numeric(), ... aggregate_function="sum", ... sort_columns=True, ... ).sort( ... by=cs.string(), ... ) shape: (4, 6) - ┌─────┬─────┬──────┬──────┬──────┬──────┐ - │ foo ┆ bar ┆ one ┆ two ┆ x ┆ y │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ - ╞═════╪═════╪══════╪══════╪══════╪══════╡ - │ one ┆ x ┆ 5 ┆ null ┆ 5 ┆ null │ - │ one ┆ y ┆ 3 ┆ null ┆ null ┆ 3 │ - │ two ┆ x ┆ null ┆ 10 ┆ 10 ┆ null │ - │ two ┆ y ┆ null ┆ 3 ┆ null ┆ 3 │ - └─────┴─────┴──────┴──────┴──────┴──────┘ + ┌─────┬─────┬─────────────┬─────────────┬─────────────┬─────────────┐ + │ foo ┆ bar ┆ {"one","x"} ┆ {"one","y"} ┆ {"two","x"} ┆ {"two","y"} │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════════════╪═════════════╪═════════════╪═════════════╡ + │ one ┆ x ┆ 5 ┆ null ┆ null ┆ null │ + │ one ┆ y ┆ null ┆ 3 ┆ null ┆ null │ + │ two ┆ x ┆ null ┆ null ┆ 10 ┆ null │ + │ two ┆ y ┆ null ┆ null ┆ null ┆ 3 │ + └─────┴─────┴─────────────┴─────────────┴─────────────┴─────────────┘ Run an expression as aggregation function @@ -7392,17 +7382,10 @@ def pivot( >>> values = pl.col("col3") >>> unique_column_values = ["x", "y"] >>> aggregate_function = lambda col: col.tanh().mean() - >>> ( - ... df.lazy() - ... .group_by(index) - ... .agg( - ... *[ - ... aggregate_function(values.filter(columns == value)).alias(value) - ... for value in unique_column_values - ... ] - ... ) - ... .collect() - ... ) # doctest: +IGNORE_RESULT + >>> df.lazy().group_by(index).agg( + ... aggregate_function(values.filter(columns == value)).alias(value) + ... for value in unique_column_values + ... ).collect() # doctest: +IGNORE_RESULT shape: (2, 3) ┌──────┬──────────┬──────────┐ │ col1 ┆ x ┆ y │ @@ -7413,9 +7396,10 @@ def pivot( │ b ┆ 0.964028 ┆ 0.999954 │ └──────┴──────────┴──────────┘ """ # noqa: W505 - values = _expand_selectors(self, values) index = _expand_selectors(self, index) columns = _expand_selectors(self, columns) + if values is not None: + values = _expand_selectors(self, values) if isinstance(aggregate_function, str): if aggregate_function == "first": @@ -7451,9 +7435,9 @@ def pivot( return self._from_pydf( self._df.pivot_expr( - values, index, columns, + values, maintain_order, sort_columns, aggregate_expr, @@ -9191,7 +9175,7 @@ def n_unique(self, subset: str | Expr | Sequence[str | Expr] | None = None) -> i In aggregate context there is also an equivalent method for returning the unique values per-group: - >>> df_agg_nunique = df.group_by(by=["a"]).n_unique() + >>> df_agg_nunique = df.group_by(["a"]).n_unique() Examples -------- @@ -9233,10 +9217,16 @@ def n_unique(self, subset: str | Expr | Sequence[str | Expr] | None = None) -> i df = self.lazy().select(expr.n_unique()).collect(_eager=True) return 0 if df.is_empty() else df.row(0)[0] + @deprecate_function( + "Use `select(pl.all().approx_n_unique())` instead.", version="0.20.11" + ) def approx_n_unique(self) -> DataFrame: """ Approximate count of unique values. + .. deprecated:: 0.20.11 + Use `select(pl.all().approx_n_unique())` instead. + This is done using the HyperLogLog++ algorithm for cardinality estimation. Examples @@ -9247,7 +9237,7 @@ def approx_n_unique(self) -> DataFrame: ... "b": [1, 2, 1, 1], ... } ... ) - >>> df.approx_n_unique() + >>> df.approx_n_unique() # doctest: +SKIP shape: (1, 2) ┌─────┬─────┐ │ a ┆ b │ @@ -9598,6 +9588,10 @@ def rows( """ Returns all data in the DataFrame as a list of rows of python-native values. + By default, each row is returned as a tuple of values given in the same order + as the frame columns. Setting `named=True` will return rows of dictionaries + instead. + Parameters ---------- named @@ -9616,12 +9610,13 @@ def rows( -------- Row-iteration is not optimal as the underlying data is stored in columnar form; where possible, prefer export via one of the dedicated export/output methods. - Where possible you should also consider using `iter_rows` instead to avoid - materialising all the data at once. + You should also consider using `iter_rows` instead, to avoid materialising all + the data at once; there is little performance difference between the two, but + peak memory can be reduced if processing rows in batches. Returns ------- - list of tuples (default) or dictionaries of row values + list of row value tuples (default), or list of dictionaries (if `named=True`). See Also -------- @@ -9661,7 +9656,10 @@ def rows_by_key( unique: bool = False, ) -> dict[Any, Iterable[Any]]: """ - Returns DataFrame data as a keyed dictionary of python-native values. + Returns all data as a dictionary of python-native values keyed by some column. + + This method is like `rows`, but instead of returning rows in a flat list, rows + are grouped by the values in the `key` column(s) and returned as a dictionary. Note that this method should not be used in place of native operations, due to the high cost of materializing all frame data out into a dictionary; it should @@ -9904,17 +9902,17 @@ def iter_rows( def iter_columns(self) -> Iterator[Series]: """ - Returns an iterator over the DataFrame's columns. + Returns an iterator over the columns of this DataFrame. + + Yields + ------ + Series Notes ----- Consider whether you can use :func:`all` instead. If you can, it will be more efficient. - Returns - ------- - Iterator of Series. - Examples -------- >>> df = pl.DataFrame( @@ -9955,7 +9953,8 @@ def iter_columns(self) -> Iterator[Series]: │ 10 ┆ 12 │ └─────┴─────┘ """ - return (wrap_s(s) for s in self._df.get_columns()) + for s in self._df.get_columns(): + yield wrap_s(s) def iter_slices(self, n_rows: int = 10_000) -> Iterator[DataFrame]: r""" diff --git a/py-polars/polars/dataframe/group_by.py b/py-polars/polars/dataframe/group_by.py index fa11d5a65946..fd89b8256bd1 100644 --- a/py-polars/polars/dataframe/group_by.py +++ b/py-polars/polars/dataframe/group_by.py @@ -35,9 +35,9 @@ class GroupBy: def __init__( self, df: DataFrame, - by: IntoExpr | Iterable[IntoExpr], - *more_by: IntoExpr, + *by: IntoExpr | Iterable[IntoExpr], maintain_order: bool, + **named_by: IntoExpr, ): """ Utility class for performing a group by operation over the given DataFrame. @@ -48,18 +48,19 @@ def __init__( ---------- df DataFrame to perform the group by operation over. - by + *by Column or columns to group by. Accepts expression input. Strings are parsed as column names. - *more_by - Additional columns to group by, specified as positional arguments. maintain_order Ensure that the order of the groups is consistent with the input data. This is slower than a default group by. + **named_by + Additional column(s) to group by, specified as keyword arguments. + The columns will be named as the keyword used. """ self.df = df self.by = by - self.more_by = more_by + self.named_by = named_by self.maintain_order = maintain_order def __iter__(self) -> Self: @@ -99,7 +100,7 @@ def __iter__(self) -> Self: temp_col = "__POLARS_GB_GROUP_INDICES" groups_df = ( self.df.lazy() - .group_by(self.by, *self.more_by, maintain_order=self.maintain_order) + .group_by(*self.by, **self.named_by, maintain_order=self.maintain_order) .agg(F.first().agg_groups().alias(temp_col)) .collect(no_optimization=True) ) @@ -107,11 +108,13 @@ def __iter__(self) -> Self: group_names = groups_df.select(F.all().exclude(temp_col)) self._group_names: Iterator[object] | Iterator[tuple[object, ...]] - key_as_single_value = isinstance(self.by, str) and not self.more_by + key_as_single_value = ( + len(self.by) == 1 and isinstance(self.by[0], str) and not self.named_by + ) if key_as_single_value: issue_deprecation_warning( "`group_by` iteration will change to always return group identifiers as tuples." - f" Pass `by` as a list to silence this warning, e.g. `group_by([{self.by!r}])`.", + f" Pass `by` as a list to silence this warning, e.g. `group_by([{self.by[0]!r}])`.", version="0.20.4", ) self._group_names = iter(group_names.to_series()) @@ -242,7 +245,7 @@ def agg( """ return ( self.df.lazy() - .group_by(self.by, *self.more_by, maintain_order=self.maintain_order) + .group_by(*self.by, **self.named_by, maintain_order=self.maintain_order) .agg(*aggs, **named_aggs) .collect(no_optimization=True) ) @@ -308,24 +311,17 @@ def map_groups(self, function: Callable[[DataFrame], DataFrame]) -> DataFrame: ... pl.int_range(pl.len()).shuffle().over("color") < 2 ... ) # doctest: +IGNORE_RESULT """ - by: list[str] - - if isinstance(self.by, str): - by = [self.by] - elif isinstance(self.by, Iterable) and all(isinstance(c, str) for c in self.by): - by = list(self.by) # type: ignore[arg-type] - else: - msg = "cannot call `map_groups` when grouping by an expression" + if self.named_by: + msg = "cannot call `map_groups` when grouping by named expressions" raise TypeError(msg) - - if all(isinstance(c, str) for c in self.more_by): - by.extend(self.more_by) # type: ignore[arg-type] - else: + if not all(isinstance(c, str) for c in self.by): msg = "cannot call `map_groups` when grouping by an expression" raise TypeError(msg) return self.df.__class__._from_pydf( - self.df._df.group_by_map_groups(by, function, self.maintain_order) + self.df._df.group_by_map_groups( + list(self.by), function, self.maintain_order + ) ) def head(self, n: int = 5) -> DataFrame: @@ -375,7 +371,7 @@ def head(self, n: int = 5) -> DataFrame: """ return ( self.df.lazy() - .group_by(self.by, *self.more_by, maintain_order=self.maintain_order) + .group_by(*self.by, **self.named_by, maintain_order=self.maintain_order) .head(n) .collect(no_optimization=True) ) @@ -427,7 +423,7 @@ def tail(self, n: int = 5) -> DataFrame: """ return ( self.df.lazy() - .group_by(self.by, *self.more_by, maintain_order=self.maintain_order) + .group_by(*self.by, **self.named_by, maintain_order=self.maintain_order) .tail(n) .collect(no_optimization=True) ) diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index d9e2fde8614c..8d85806b4cd7 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -90,6 +90,8 @@ def is_nested(self) -> bool: # noqa: D102 class DataType(metaclass=DataTypeClass): """Base class for all Polars data types.""" + __slots__ = () + def __reduce__(self) -> Any: return (_custom_reconstruct, (type(self), object, None), self.__dict__) @@ -235,6 +237,7 @@ def _custom_reconstruct( class DataTypeGroup(frozenset): # type: ignore[type-arg] """Group of data types.""" + __slots__ = ("_match_base_type",) _match_base_type: bool def __new__( @@ -267,70 +270,104 @@ def __contains__(self, item: Any) -> bool: class NumericType(DataType): """Base class for numeric data types.""" + __slots__ = () + class IntegerType(NumericType): """Base class for integer data types.""" + __slots__ = () + class SignedIntegerType(IntegerType): """Base class for signed integer data types.""" + __slots__ = () + class UnsignedIntegerType(IntegerType): """Base class for unsigned integer data types.""" + __slots__ = () + class FloatType(NumericType): """Base class for float data types.""" + __slots__ = () + class TemporalType(DataType): """Base class for temporal data types.""" + __slots__ = () + class NestedType(DataType): """Base class for nested data types.""" + __slots__ = () + class Int8(SignedIntegerType): """8-bit signed integer type.""" + __slots__ = () + class Int16(SignedIntegerType): """16-bit signed integer type.""" + __slots__ = () + class Int32(SignedIntegerType): """32-bit signed integer type.""" + __slots__ = () + class Int64(SignedIntegerType): """64-bit signed integer type.""" + __slots__ = () + class UInt8(UnsignedIntegerType): """8-bit unsigned integer type.""" + __slots__ = () + class UInt16(UnsignedIntegerType): """16-bit unsigned integer type.""" + __slots__ = () + class UInt32(UnsignedIntegerType): """32-bit unsigned integer type.""" + __slots__ = () + class UInt64(UnsignedIntegerType): """64-bit unsigned integer type.""" + __slots__ = () + class Float32(FloatType): """32-bit floating point type.""" + __slots__ = () + class Float64(FloatType): """64-bit floating point type.""" + __slots__ = () + class Decimal(NumericType): """ @@ -340,8 +377,17 @@ class Decimal(NumericType): This functionality is considered **unstable**. It is a work-in-progress feature and may not always work as expected. It may be changed at any point without it being considered a breaking change. + + Parameters + ---------- + precision + Maximum number of digits in each number. + If set to `None` (default), the precision is inferred. + scale + Number of digits to the right of the decimal point in each number. """ + __slots__ = ("precision", "scale") precision: int | None scale: int @@ -383,10 +429,14 @@ def __hash__(self) -> int: class Boolean(DataType): """Boolean type.""" + __slots__ = () + class String(DataType): """UTF-8 encoded string type.""" + __slots__ = () + # Allow Utf8 as an alias for String Utf8 = String @@ -395,50 +445,88 @@ class String(DataType): class Binary(DataType): """Binary type.""" + __slots__ = () + class Date(TemporalType): - """Calendar date type.""" + """ + Data type representing a calendar date. + + Notes + ----- + The underlying representation of this type is a 32-bit signed integer. + The integer indicates the number of days since the Unix epoch (1970-01-01). + The number can be negative to indicate dates before the epoch. + """ + + __slots__ = () class Time(TemporalType): - """Time of day type.""" + """ + Data type representing the time of day. + + Notes + ----- + The underlying representation of this type is a 64-bit signed integer. + The integer indicates the number of nanoseconds since midnight. + """ + + __slots__ = () class Datetime(TemporalType): - """Calendar date and time type.""" + """ + Data type representing a calendar date and time of day. + + Parameters + ---------- + time_unit : {'us', 'ns', 'ms'} + Unit of time. Defaults to `'us'` (microseconds). + time_zone + Time zone string, as defined in zoneinfo (to see valid strings run + `import zoneinfo; zoneinfo.available_timezones()` for a full list). + When using to match dtypes, can use "*" to check for Datetime columns + that have any timezone. + + Notes + ----- + The underlying representation of this type is a 64-bit signed integer. + The integer indicates the number of time units since the Unix epoch + (1970-01-01 00:00:00). The number can be negative to indicate datetimes before the + epoch. + """ time_unit: TimeUnit | None = None time_zone: str | None = None def __init__( - self, time_unit: TimeUnit | None = "us", time_zone: str | timezone | None = None + self, time_unit: TimeUnit = "us", time_zone: str | timezone | None = None ): - """ - Calendar date and time type. - - Parameters - ---------- - time_unit : {'us', 'ns', 'ms'} - Unit of time / precision. - time_zone - Time zone string, as defined in zoneinfo (to see valid strings run - `import zoneinfo; zoneinfo.available_timezones()` for a full list). - When using to match dtypes, can use "*" to check for Datetime columns - that have any timezone. - """ - if isinstance(time_zone, timezone): - time_zone = str(time_zone) - - self.time_unit = time_unit or "us" - self.time_zone = time_zone + if time_unit is None: + from polars.utils.deprecation import issue_deprecation_warning + + issue_deprecation_warning( + "Passing `time_unit=None` to the Datetime constructor is deprecated." + " Either avoid passing a time unit to use the default value ('us')," + " or pass a valid time unit instead ('ms', 'us', 'ns').", + version="0.20.11", + ) + time_unit = "us" - if self.time_unit not in ("ms", "us", "ns"): + if time_unit not in ("ms", "us", "ns"): msg = ( "invalid `time_unit`" - f"\n\nExpected one of {{'ns','us','ms'}}, got {self.time_unit!r}." + f"\n\nExpected one of {{'ns','us','ms'}}, got {time_unit!r}." ) raise ValueError(msg) + if isinstance(time_zone, timezone): + time_zone = str(time_zone) + + self.time_unit = time_unit + self.time_zone = time_zone + def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] # allow comparing object instances to class if type(other) is DataTypeClass and issubclass(other, Datetime): @@ -461,27 +549,33 @@ def __repr__(self) -> str: class Duration(TemporalType): - """Time duration/delta type.""" + """ + Data type representing a time duration. + + Parameters + ---------- + time_unit : {'us', 'ns', 'ms'} + Unit of time. Defaults to `'us'` (microseconds). + + Notes + ----- + The underlying representation of this type is a 64-bit signed integer. + The integer indicates an amount of time units and can be negative to indicate + negative time offsets. + """ time_unit: TimeUnit | None = None def __init__(self, time_unit: TimeUnit = "us"): - """ - Time duration/delta type. - - Parameters - ---------- - time_unit : {'us', 'ns', 'ms'} - Unit of time. - """ - self.time_unit = time_unit - if self.time_unit not in ("ms", "us", "ns"): + if time_unit not in ("ms", "us", "ns"): msg = ( "invalid `time_unit`" - f"\n\nExpected one of {{'ns','us','ms'}}, got {self.time_unit!r}." + f"\n\nExpected one of {{'ns','us','ms'}}, got {time_unit!r}." ) raise ValueError(msg) + self.time_unit = time_unit + def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] # allow comparing object instances to class if type(other) is DataTypeClass and issubclass(other, Duration): @@ -505,11 +599,12 @@ class Categorical(DataType): Parameters ---------- - ordering : {'lexical', 'physical'} - Ordering by order of appearance (physical, default) - or string value (lexical). + ordering : {'lexical', 'physical'} + Ordering by order of appearance (`'physical'`, default) + or string value (`'lexical'`). """ + __slots__ = ("ordering",) ordering: CategoricalOrdering | None def __init__( @@ -542,19 +637,17 @@ class Enum(DataType): This functionality is considered **unstable**. It is a work-in-progress feature and may not always work as expected. It may be changed at any point without it being considered a breaking change. + + Parameters + ---------- + categories + The categories in the dataset. Categories must be strings. """ + __slots__ = ("categories",) categories: Series def __init__(self, categories: Series | Iterable[str]): - """ - A fixed set categorical encoding of a set of strings. - - Parameters - ---------- - categories - Valid categories in the dataset. - """ # Issuing the warning on `__init__` does not trigger when the class is used # without being instantiated, but it's better than nothing from polars.utils.unstable import issue_unstable_warning @@ -604,50 +697,55 @@ def __repr__(self) -> str: class Object(DataType): - """Type for wrapping arbitrary Python objects.""" + """Data type for wrapping arbitrary Python objects.""" + + __slots__ = () class Null(DataType): - """Type representing Null / None values.""" + """Data type representing null values.""" + + __slots__ = () class Unknown(DataType): - """Type representing Datatype values that could not be determined statically.""" + """Type representing DataType values that could not be determined statically.""" + + __slots__ = () class List(NestedType): - """Variable length list type.""" + """ + Variable length list type. + + Parameters + ---------- + inner + The `DataType` of the values within each list. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "integer_lists": [[1, 2], [3, 4]], + ... "float_lists": [[1.0, 2.0], [3.0, 4.0]], + ... } + ... ) + >>> df + shape: (2, 2) + ┌───────────────┬─────────────┐ + │ integer_lists ┆ float_lists │ + │ --- ┆ --- │ + │ list[i64] ┆ list[f64] │ + ╞═══════════════╪═════════════╡ + │ [1, 2] ┆ [1.0, 2.0] │ + │ [3, 4] ┆ [3.0, 4.0] │ + └───────────────┴─────────────┘ + """ inner: PolarsDataType | None = None def __init__(self, inner: PolarsDataType | PythonDataType): - """ - Variable length list type. - - Parameters - ---------- - inner - The `DataType` of the values within each list. - - Examples - -------- - >>> df = pl.DataFrame( - ... { - ... "integer_lists": [[1, 2], [3, 4]], - ... "float_lists": [[1.0, 2.0], [3.0, 4.0]], - ... } - ... ) - >>> df - shape: (2, 2) - ┌───────────────┬─────────────┐ - │ integer_lists ┆ float_lists │ - │ --- ┆ --- │ - │ list[i64] ┆ list[f64] │ - ╞═══════════════╪═════════════╡ - │ [1, 2] ┆ [1.0, 2.0] │ - │ [3, 4] ┆ [3.0, 4.0] │ - └───────────────┴─────────────┘ - """ self.inner = polars.datatypes.py_type_to_dtype(inner) def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] @@ -677,33 +775,32 @@ def __repr__(self) -> str: class Array(NestedType): - """Fixed length list type.""" + """ + Fixed length list type. + + Parameters + ---------- + inner + The `DataType` of the values within each array. + width + The length of the arrays. + + Examples + -------- + >>> s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2)) + >>> s + shape: (2,) + Series: 'a' [array[i64, 2]] + [ + [1, 2] + [4, 3] + ] + """ inner: PolarsDataType | None = None width: int def __init__(self, inner: PolarsDataType | PythonDataType, width: int): - """ - Fixed length list type. - - Parameters - ---------- - inner - The `DataType` of the values within each array. - width - The length of the arrays. - - Examples - -------- - >>> s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2)) - >>> s - shape: (2,) - Series: 'a' [array[i64, 2]] - [ - [1, 2] - [4, 3] - ] - """ self.inner = polars.datatypes.py_type_to_dtype(inner) self.width = width @@ -736,19 +833,22 @@ def __repr__(self) -> str: class Field: - """Definition of a single field within a `Struct` DataType.""" + """ + Definition of a single field within a `Struct` DataType. - def __init__(self, name: str, dtype: PolarsDataType): - """ - Definition of a single field within a `Struct` DataType. + Parameters + ---------- + name + The name of the field within its parent `Struct`. + dtype + The `DataType` of the field's values. + """ - Parameters - ---------- - name - The name of the field within its parent `Struct` - dtype - The `DataType` of the field's values - """ + __slots__ = ("name", "dtype") + name: str + dtype: PolarsDataType + + def __init__(self, name: str, dtype: PolarsDataType): self.name = name self.dtype = polars.datatypes.py_type_to_dtype(dtype) @@ -764,49 +864,47 @@ def __repr__(self) -> str: class Struct(NestedType): - """Struct composite type.""" + """ + Struct composite type. + + Parameters + ---------- + fields + The fields that make up the struct. Can be either a sequence of Field + objects or a mapping of column names to data types. + + Examples + -------- + Initialize using a dictionary: + + >>> dtype = pl.Struct({"a": pl.Int8, "b": pl.List(pl.String)}) + >>> dtype + Struct({'a': Int8, 'b': List(String)}) + + Initialize using a list of Field objects: + + >>> dtype = pl.Struct([pl.Field("a", pl.Int8), pl.Field("b", pl.List(pl.String))]) + >>> dtype + Struct({'a': Int8, 'b': List(String)}) + + When initializing a Series, Polars can infer a struct data type from the data. + + >>> s = pl.Series([{"a": 1, "b": ["x", "y"]}, {"a": 2, "b": ["z"]}]) + >>> s + shape: (2,) + Series: '' [struct[2]] + [ + {1,["x", "y"]} + {2,["z"]} + ] + >>> s.dtype + Struct({'a': Int64, 'b': List(String)}) + """ + __slots__ = ("fields",) fields: list[Field] def __init__(self, fields: Sequence[Field] | SchemaDict): - """ - Struct composite type. - - Parameters - ---------- - fields - The fields that make up the struct. Can be either a sequence of Field - objects or a mapping of column names to data types. - - Examples - -------- - Initialize using a dictionary: - - >>> dtype = pl.Struct({"a": pl.Int8, "b": pl.List(pl.String)}) - >>> dtype - Struct({'a': Int8, 'b': List(String)}) - - Initialize using a list of Field objects: - - >>> dtype = pl.Struct( - ... [pl.Field("a", pl.Int8), pl.Field("b", pl.List(pl.String))] - ... ) - >>> dtype - Struct({'a': Int8, 'b': List(String)}) - - When initializing a Series, Polars can infer a struct data type from the data. - - >>> s = pl.Series([{"a": 1, "b": ["x", "y"]}, {"a": 2, "b": ["z"]}]) - >>> s - shape: (2,) - Series: '' [struct[2]] - [ - {1,["x", "y"]} - {2,["z"]} - ] - >>> s.dtype - Struct({'a': Int64, 'b': List(String)}) - """ if isinstance(fields, Mapping): self.fields = [Field(name, dtype) for name, dtype in fields.items()] else: diff --git a/py-polars/polars/datatypes/constructor.py b/py-polars/polars/datatypes/constructor.py index a164ba4f11d5..d7d1e23eab7c 100644 --- a/py-polars/polars/datatypes/constructor.py +++ b/py-polars/polars/datatypes/constructor.py @@ -135,7 +135,9 @@ def numpy_type_to_constructor( return _NUMPY_TYPE_TO_CONSTRUCTOR[dtype] # type:ignore[index] except KeyError: if len(values) > 0: - first_non_nan = next((v for v in values if v == v), None) + first_non_nan = next( + (v for v in values if isinstance(v, np.ndarray) or v == v), None + ) if isinstance(first_non_nan, str): return PySeries.new_str if isinstance(first_non_nan, bytes): diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index 9e13e82e4896..de84dcaaedd1 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -140,10 +140,10 @@ def _map_py_type_to_dtype( def is_polars_dtype(dtype: Any, *, include_unknown: bool = False) -> bool: - """Indicate whether the given input is a Polars dtype, or dtype specialisation.""" + """Indicate whether the given input is a Polars dtype, or dtype specialization.""" try: if dtype == Unknown: - # does not represent a realisable dtype, so ignore by default + # does not represent a realizable dtype, so ignore by default return include_unknown else: return isinstance(dtype, (DataType, DataTypeClass)) @@ -246,12 +246,12 @@ def DTYPE_TO_CTYPE(self) -> dict[PolarsDataType, Any]: Int8: ctypes.c_int8, Int16: ctypes.c_int16, Int32: ctypes.c_int32, - Date: ctypes.c_int32, Int64: ctypes.c_int64, Float32: ctypes.c_float, Float64: ctypes.c_double, Datetime: ctypes.c_int64, Duration: ctypes.c_int64, + Date: ctypes.c_int32, Time: ctypes.c_int64, } @@ -298,6 +298,8 @@ def NUMPY_KIND_AND_ITEMSIZE_TO_DTYPE(self) -> dict[tuple[str, int], PolarsDataTy ("u", 8): UInt64, ("f", 4): Float32, ("f", 8): Float64, + ("m", 8): Duration, + ("M", 8): Datetime, } @property @@ -469,6 +471,8 @@ def numpy_char_code_to_dtype(dtype_char: str) -> PolarsDataType: dtype = np.dtype(dtype_char) if dtype.kind == "U": return String + elif dtype.kind == "S": + return Binary try: return DataTypeMappings.NUMPY_KIND_AND_ITEMSIZE_TO_DTYPE[ (dtype.kind, dtype.itemsize) diff --git a/py-polars/polars/dependencies.py b/py-polars/polars/dependencies.py index d987b36d6ac4..1cc61eb4609c 100644 --- a/py-polars/polars/dependencies.py +++ b/py-polars/polars/dependencies.py @@ -8,7 +8,6 @@ from types import ModuleType from typing import TYPE_CHECKING, Any, ClassVar, Hashable, cast -_DATAFRAME_API_COMPAT_AVAILABLE = True _DELTALAKE_AVAILABLE = True _FSSPEC_AVAILABLE = True _GEVENT_AVAILABLE = True @@ -150,7 +149,6 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: import pickle import subprocess - import dataframe_api_compat import deltalake import fsspec import gevent @@ -175,9 +173,6 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: subprocess, _ = _lazy_import("subprocess") # heavy/optional third party libs - dataframe_api_compat, _DATAFRAME_API_COMPAT_AVAILABLE = _lazy_import( - "dataframe_api_compat" - ) deltalake, _DELTALAKE_AVAILABLE = _lazy_import("deltalake") fsspec, _FSSPEC_AVAILABLE = _lazy_import("fsspec") hvplot, _HVPLOT_AVAILABLE = _lazy_import("hvplot") @@ -281,7 +276,6 @@ def import_optional( "pickle", "subprocess", # lazy-load third party libs - "dataframe_api_compat", "deltalake", "fsspec", "gevent", diff --git a/py-polars/polars/exceptions.py b/py-polars/polars/exceptions.py index 702fb938f7d8..1fe2b9db7ea6 100644 --- a/py-polars/polars/exceptions.py +++ b/py-polars/polars/exceptions.py @@ -121,6 +121,10 @@ class ArrowError(Exception): """Deprecated: will be removed.""" +class CustomUFuncWarning(PolarsWarning): # type: ignore[misc] + """Warning issued when a custom ufunc is handled differently than numpy ufunc would.""" # noqa: W505 + + __all__ = [ "ArrowError", "ColumnNotFoundError", diff --git a/py-polars/polars/expr/array.py b/py-polars/polars/expr/array.py index 368b4fb4cfe9..b228b7b562b7 100644 --- a/py-polars/polars/expr/array.py +++ b/py-polars/polars/expr/array.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Sequence from polars.utils._parse_expr_input import parse_as_expression from polars.utils._wrap import wrap_expr @@ -89,6 +89,75 @@ def sum(self) -> Expr: """ return wrap_expr(self._pyexpr.arr_sum()) + def std(self, ddof: int = 1) -> Expr: + """ + Compute the std of the values of the sub-arrays. + + Examples + -------- + >>> df = pl.DataFrame( + ... data={"a": [[1, 2], [4, 3]]}, + ... schema={"a": pl.Array(pl.Int64, 2)}, + ... ) + >>> df.select(pl.col("a").arr.std()) + shape: (2, 1) + ┌──────────┐ + │ a │ + │ --- │ + │ f64 │ + ╞══════════╡ + │ 0.707107 │ + │ 0.707107 │ + └──────────┘ + """ + return wrap_expr(self._pyexpr.arr_std(ddof)) + + def var(self, ddof: int = 1) -> Expr: + """ + Compute the var of the values of the sub-arrays. + + Examples + -------- + >>> df = pl.DataFrame( + ... data={"a": [[1, 2], [4, 3]]}, + ... schema={"a": pl.Array(pl.Int64, 2)}, + ... ) + >>> df.select(pl.col("a").arr.var()) + shape: (2, 1) + ┌─────┐ + │ a │ + │ --- │ + │ f64 │ + ╞═════╡ + │ 0.5 │ + │ 0.5 │ + └─────┘ + """ + return wrap_expr(self._pyexpr.arr_var(ddof)) + + def median(self) -> Expr: + """ + Compute the median of the values of the sub-arrays. + + Examples + -------- + >>> df = pl.DataFrame( + ... data={"a": [[1, 2], [4, 3]]}, + ... schema={"a": pl.Array(pl.Int64, 2)}, + ... ) + >>> df.select(pl.col("a").arr.median()) + shape: (2, 1) + ┌─────┐ + │ a │ + │ --- │ + │ f64 │ + ╞═════╡ + │ 1.5 │ + │ 3.5 │ + └─────┘ + """ + return wrap_expr(self._pyexpr.arr_median()) + def unique(self, *, maintain_order: bool = False) -> Expr: """ Get the unique/distinct values in the array. @@ -574,3 +643,110 @@ def count_matches(self, element: IntoExpr) -> Expr: """ element = parse_as_expression(element, str_as_lit=True) return wrap_expr(self._pyexpr.arr_count_matches(element)) + + def to_struct( + self, fields: Sequence[str] | Callable[[int], str] | None = None + ) -> Expr: + """ + Convert the Series of type `Array` to a Series of type `Struct`. + + Parameters + ---------- + fields + If the name and number of the desired fields is known in advance + a list of field names can be given, which will be assigned by index. + Otherwise, to dynamically assign field names, a custom function can be + used; if neither are set, fields will be `field_0, field_1 .. field_n`. + + Examples + -------- + Convert array to struct with default field name assignment: + + >>> df = pl.DataFrame( + ... {"n": [[0, 1, 2], [3, 4, 5]]}, schema={"n": pl.Array(pl.Int8, 3)} + ... ) + >>> df.with_columns(struct=pl.col("n").arr.to_struct()) + shape: (2, 2) + ┌──────────────┬───────────┐ + │ n ┆ struct │ + │ --- ┆ --- │ + │ array[i8, 3] ┆ struct[3] │ + ╞══════════════╪═══════════╡ + │ [0, 1, 2] ┆ {0,1,2} │ + │ [3, 4, 5] ┆ {3,4,5} │ + └──────────────┴───────────┘ + + Convert array to struct with field name assignment by function/index: + + >>> df = pl.DataFrame( + ... {"n": [[0, 1, 2], [3, 4, 5]]}, schema={"n": pl.Array(pl.Int8, 3)} + ... ) + >>> df.select(pl.col("n").arr.to_struct(fields=lambda idx: f"n{idx}")).rows( + ... named=True + ... ) + [{'n': {'n0': 0, 'n1': 1, 'n2': 2}}, {'n': {'n0': 3, 'n1': 4, 'n2': 5}}] + + Convert array to struct with field name assignment by + index from a list of names: + + >>> df.select(pl.col("n").arr.to_struct(fields=["c1", "c2", "c3"])).rows( + ... named=True + ... ) + [{'n': {'c1': 0, 'c2': 1, 'c3': 2}}, {'n': {'c1': 3, 'c2': 4, 'c3': 5}}] + """ + if isinstance(fields, Sequence): + field_names = list(fields) + pyexpr = self._pyexpr.arr_to_struct(None) + return wrap_expr(pyexpr).struct.rename_fields(field_names) + else: + pyexpr = self._pyexpr.arr_to_struct(fields) + return wrap_expr(pyexpr) + + def shift(self, n: int | IntoExprColumn = 1) -> Expr: + """ + Shift array values by the given number of indices. + + Parameters + ---------- + n + Number of indices to shift forward. If a negative value is passed, values + are shifted in the opposite direction instead. + + Notes + ----- + This method is similar to the `LAG` operation in SQL when the value for `n` + is positive. With a negative value for `n`, it is similar to `LEAD`. + + Examples + -------- + By default, array values are shifted forward by one index. + + >>> df = pl.DataFrame( + ... {"a": [[1, 2, 3], [4, 5, 6]]}, schema={"a": pl.Array(pl.Int64, 3)} + ... ) + >>> df.with_columns(shift=pl.col("a").arr.shift()) + shape: (2, 2) + ┌───────────────┬───────────────┐ + │ a ┆ shift │ + │ --- ┆ --- │ + │ array[i64, 3] ┆ array[i64, 3] │ + ╞═══════════════╪═══════════════╡ + │ [1, 2, 3] ┆ [null, 1, 2] │ + │ [4, 5, 6] ┆ [null, 4, 5] │ + └───────────────┴───────────────┘ + + Pass a negative value to shift in the opposite direction instead. + + >>> df.with_columns(shift=pl.col("a").arr.shift(-2)) + shape: (2, 2) + ┌───────────────┬─────────────────┐ + │ a ┆ shift │ + │ --- ┆ --- │ + │ array[i64, 3] ┆ array[i64, 3] │ + ╞═══════════════╪═════════════════╡ + │ [1, 2, 3] ┆ [3, null, null] │ + │ [4, 5, 6] ┆ [6, null, null] │ + └───────────────┴─────────────────┘ + """ + n = parse_as_expression(n) + return wrap_expr(self._pyexpr.arr_shift(n)) diff --git a/py-polars/polars/expr/binary.py b/py-polars/polars/expr/binary.py index 461c188822a7..ed1af57c8b2d 100644 --- a/py-polars/polars/expr/binary.py +++ b/py-polars/polars/expr/binary.py @@ -162,8 +162,8 @@ def starts_with(self, prefix: IntoExpr) -> Expr: return wrap_expr(self._pyexpr.bin_starts_with(prefix)) def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: - """ - Decode a value using the provided encoding. + r""" + Decode values using the provided encoding. Parameters ---------- @@ -172,6 +172,33 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: strict Raise an error if the underlying value cannot be decoded, otherwise mask out with a null value. + + Returns + ------- + Expr + Expression of data type :class:`String`. + + Examples + -------- + >>> colors = pl.DataFrame( + ... { + ... "name": ["black", "yellow", "blue"], + ... "code": [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"], + ... } + ... ) + >>> colors.with_columns( + ... pl.col("code").bin.encode("hex").alias("encoded"), + ... ) + shape: (3, 3) + ┌────────┬─────────────────┬─────────┐ + │ name ┆ code ┆ encoded │ + │ --- ┆ --- ┆ --- │ + │ str ┆ binary ┆ str │ + ╞════════╪═════════════════╪═════════╡ + │ black ┆ b"\x00\x00\x00" ┆ 000000 │ + │ yellow ┆ b"\xff\xff\x00" ┆ ffff00 │ + │ blue ┆ b"\x00\x00\xff" ┆ 0000ff │ + └────────┴─────────────────┴─────────┘ """ if encoding == "hex": return wrap_expr(self._pyexpr.bin_hex_decode(strict)) @@ -193,30 +220,29 @@ def encode(self, encoding: TransferEncoding) -> Expr: Returns ------- Expr - Expression of data type :class:`String` with values encoded using provided - encoding. + Expression of data type :class:`String`. Examples -------- >>> colors = pl.DataFrame( ... { - ... "name": ["black", "yellow", "blue"], + ... "color": ["black", "yellow", "blue"], ... "code": [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"], ... } ... ) >>> colors.with_columns( - ... pl.col("code").bin.encode("hex").alias("code_encoded_hex"), + ... pl.col("code").bin.encode("hex").alias("encoded"), ... ) shape: (3, 3) - ┌────────┬─────────────────┬──────────────────┐ - │ name ┆ code ┆ code_encoded_hex │ - │ --- ┆ --- ┆ --- │ - │ str ┆ binary ┆ str │ - ╞════════╪═════════════════╪══════════════════╡ - │ black ┆ b"\x00\x00\x00" ┆ 000000 │ - │ yellow ┆ b"\xff\xff\x00" ┆ ffff00 │ - │ blue ┆ b"\x00\x00\xff" ┆ 0000ff │ - └────────┴─────────────────┴──────────────────┘ + ┌────────┬─────────────────┬─────────┐ + │ color ┆ code ┆ encoded │ + │ --- ┆ --- ┆ --- │ + │ str ┆ binary ┆ str │ + ╞════════╪═════════════════╪═════════╡ + │ black ┆ b"\x00\x00\x00" ┆ 000000 │ + │ yellow ┆ b"\xff\xff\x00" ┆ ffff00 │ + │ blue ┆ b"\x00\x00\xff" ┆ 0000ff │ + └────────┴─────────────────┴─────────┘ """ if encoding == "hex": return wrap_expr(self._pyexpr.bin_hex_encode()) diff --git a/py-polars/polars/expr/categorical.py b/py-polars/polars/expr/categorical.py index f60a3bc48f7a..89ecef5188ea 100644 --- a/py-polars/polars/expr/categorical.py +++ b/py-polars/polars/expr/categorical.py @@ -34,7 +34,7 @@ def set_ordering(self, ordering: CategoricalOrdering) -> Expr: Ordering type: - 'physical' -> Use the physical representation of the categories to - determine the order (default). + determine the order (default). - 'lexical' -> Use the string values to determine the ordering. """ return wrap_expr(self._pyexpr.cat_set_ordering(ordering)) diff --git a/py-polars/polars/expr/datetime.py b/py-polars/polars/expr/datetime.py index 3a73c83779e5..5debd23daa5b 100644 --- a/py-polars/polars/expr/datetime.py +++ b/py-polars/polars/expr/datetime.py @@ -244,7 +244,7 @@ def round( - `'earliest'`: use the earliest datetime - `'latest'`: use the latest datetime - .. deprecated: 0.19.3 + .. deprecated:: 0.19.3 This is now auto-inferred, you can safely remove this argument. Returns @@ -1548,6 +1548,11 @@ def convert_time_zone(self, time_zone: str) -> Expr: time_zone Time zone for the `Datetime` expression. + Notes + ----- + If converting from a time-zone-naive datetime, then conversion will happen + as if converting from UTC, regardless of your system's time zone. + Examples -------- >>> from datetime import datetime diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index de83d8da4da7..dcfa733d378c 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -3,10 +3,11 @@ import contextlib import math import operator -import os import warnings from datetime import timedelta -from functools import partial, reduce +from functools import reduce +from io import BytesIO, StringIO +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -31,7 +32,7 @@ ) from polars.dependencies import _check_for_numpy from polars.dependencies import numpy as np -from polars.exceptions import PolarsInefficientMapWarning +from polars.exceptions import CustomUFuncWarning, PolarsInefficientMapWarning from polars.expr.array import ExprArrayNameSpace from polars.expr.binary import ExprBinaryNameSpace from polars.expr.categorical import ExprCatNameSpace @@ -41,6 +42,7 @@ from polars.expr.name import ExprNameNameSpace from polars.expr.string import ExprStringNameSpace from polars.expr.struct import ExprStructNameSpace +from polars.meta import thread_pool_size from polars.utils._parse_expr_input import ( parse_as_expression, parse_as_list_of_expressions, @@ -55,23 +57,25 @@ deprecate_saturating, issue_deprecation_warning, ) -from polars.utils.meta import threadpool_size from polars.utils.unstable import issue_unstable_warning, unstable from polars.utils.various import ( + BUILDING_SPHINX_DOCS, + find_stacklevel, no_default, + normalize_filepath, sphinx_accessor, warn_null_comparison, ) with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import arg_where as py_arg_where - from polars.polars import reduce as pyreduce with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyExpr if TYPE_CHECKING: import sys + from io import IOBase from polars import DataFrame, LazyFrame, Series from polars.type_aliases import ( @@ -84,7 +88,6 @@ NullBehavior, NumericLiteral, PolarsDataType, - PythonLiteral, RankMethod, RollingInterpolationMethod, SearchSortedSide, @@ -103,14 +106,15 @@ T = TypeVar("T") P = ParamSpec("P") -elif os.getenv("BUILDING_SPHINX_DOCS"): +elif BUILDING_SPHINX_DOCS: property = sphinx_accessor class Expr: """Expressions that can be used in various contexts.""" - _pyexpr: PyExpr = None + __slots__ = ("_pyexpr",) + _pyexpr: PyExpr _accessors: ClassVar[set[str]] = { "arr", "cat", @@ -143,8 +147,13 @@ def __str__(self) -> str: def __bool__(self) -> NoReturn: msg = ( "the truth value of an Expr is ambiguous" - "\n\nHint: use '&' or '|' to logically combine Expr, not 'and'/'or', and" - " use `x.is_in([y,z])` instead of `x in [y,z]` to check membership." + "\n\n" + "You probably got here by using a Python standard library function instead " + "of the native expressions API.\n" + "Here are some things you might want to try:\n" + "- instead of `pl.col('a') and pl.col('b')`, use `pl.col('a') & pl.col('b')`\n" + "- instead of `pl.col('a') in [y, z]`, use `pl.col('a').is_in([y, z])`\n" + "- instead of `max(pl.col('a'), pl.col('b'))`, use `pl.max_horizontal(pl.col('a'), pl.col('b'))`\n" ) raise TypeError(msg) @@ -162,11 +171,11 @@ def __radd__(self, other: IntoExpr) -> Self: def __and__(self, other: IntoExprColumn | int | bool) -> Self: other = parse_as_expression(other) - return self._from_pyexpr(self._pyexpr._and(other)) + return self._from_pyexpr(self._pyexpr.and_(other)) def __rand__(self, other: IntoExprColumn | int | bool) -> Self: other_expr = parse_as_expression(other) - return self._from_pyexpr(other_expr._and(self._pyexpr)) + return self._from_pyexpr(other_expr.and_(self._pyexpr)) def __eq__(self, other: IntoExpr) -> Self: # type: ignore[override] warn_null_comparison(other) @@ -230,11 +239,11 @@ def __neg__(self) -> Self: def __or__(self, other: IntoExprColumn | int | bool) -> Self: other = parse_as_expression(other) - return self._from_pyexpr(self._pyexpr._or(other)) + return self._from_pyexpr(self._pyexpr.or_(other)) def __ror__(self, other: IntoExprColumn | int | bool) -> Self: other_expr = parse_as_expression(other) - return self._from_pyexpr(other_expr._or(self._pyexpr)) + return self._from_pyexpr(other_expr.or_(self._pyexpr)) def __pos__(self) -> Expr: return self @@ -265,11 +274,11 @@ def __rtruediv__(self, other: IntoExpr) -> Self: def __xor__(self, other: IntoExprColumn | int | bool) -> Self: other = parse_as_expression(other) - return self._from_pyexpr(self._pyexpr._xor(other)) + return self._from_pyexpr(self._pyexpr.xor_(other)) def __rxor__(self, other: IntoExprColumn | int | bool) -> Self: other_expr = parse_as_expression(other) - return self._from_pyexpr(other_expr._xor(self._pyexpr)) + return self._from_pyexpr(other_expr.xor_(self._pyexpr)) def __getstate__(self) -> bytes: return self._pyexpr.__getstate__() @@ -282,36 +291,79 @@ def __array_ufunc__( self, ufunc: Callable[..., Any], method: str, *inputs: Any, **kwargs: Any ) -> Self: """Numpy universal functions.""" + if method != "__call__": + msg = f"Only call is implemented not {method}" + raise NotImplementedError(msg) + is_custom_ufunc = ufunc.__class__ != np.ufunc num_expr = sum(isinstance(inp, Expr) for inp in inputs) - if num_expr > 1: - if num_expr < len(inputs): - msg = ( - "NumPy ufunc with more than one expression can only be used" - " if all non-expression inputs are provided as keyword arguments only" - ) - raise ValueError(msg) - - exprs = parse_as_list_of_expressions(inputs) - return self._from_pyexpr(pyreduce(partial(ufunc, **kwargs), exprs)) + exprs = [ + (inp, Expr, i) if isinstance(inp, Expr) else (inp, None, i) + for i, inp in enumerate(inputs) + ] + if num_expr == 1: + root_expr = next(expr[0] for expr in exprs if expr[1] == Expr) + else: + root_expr = F.struct(expr[0] for expr in exprs if expr[1] == Expr) def function(s: Series) -> Series: # pragma: no cover - args = [inp if not isinstance(inp, Expr) else s for inp in inputs] + args = [] + for i, expr in enumerate(exprs): + if expr[1] == Expr and num_expr > 1: + args.append(s.struct[i]) + elif expr[1] == Expr: + args.append(s) + else: + args.append(expr[0]) return ufunc(*args, **kwargs) - return self.map_batches(function) + if is_custom_ufunc is True: + msg = ( + "Native numpy ufuncs are dispatched using `map_batches(ufunc, is_elementwise=True)` which " + "is safe for native Numpy and Scipy ufuncs but custom ufuncs in a group_by " + "context won't be properly grouped. Custom ufuncs are dispatched with is_elementwise=False. " + f"If {ufunc.__name__} needs elementwise then please use map_batches directly." + ) + warnings.warn( + msg, + CustomUFuncWarning, + stacklevel=find_stacklevel(), + ) + return root_expr.map_batches( + function, is_elementwise=False + ).meta.undo_aliases() + return root_expr.map_batches(function, is_elementwise=True).meta.undo_aliases() @classmethod - def from_json(cls, value: str) -> Self: + def deserialize(cls, source: str | Path | IOBase) -> Self: """ - Read an expression from a JSON encoded string to construct an Expression. + Read an expression from a JSON file. Parameters ---------- - value - JSON encoded string value + source + Path to a file or a file-like object (by file-like object, we refer to + objects that have a `read()` method, such as a file handler (e.g. + via builtin `open` function) or `BytesIO`). + + See Also + -------- + Expr.meta.serialize + + Examples + -------- + >>> from io import StringIO + >>> expr = pl.col("foo").sum().over("bar") + >>> json = expr.meta.serialize() + >>> pl.Expr.deserialize(StringIO(json)) # doctest: +ELLIPSIS + """ + if isinstance(source, StringIO): + source = BytesIO(source.getvalue().encode()) + elif isinstance(source, (str, Path)): + source = normalize_filepath(source) + expr = cls.__new__(cls) - expr._pyexpr = PyExpr.meta_read_json(value) + expr._pyexpr = PyExpr.deserialize(source) return expr def to_physical(self) -> Self: @@ -3271,7 +3323,7 @@ def rolling( check_sorted: bool = True, ) -> Self: """ - Create rolling groups based on a time, Int32, or Int64 column. + Create rolling groups based on a temporal or integer column. If you have a time series ``, then by default the windows created will be @@ -3311,11 +3363,6 @@ def rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a rolling operation on an integer column, the windows are defined by: - - - "1i" # length 1 - - "10i" # length 10 - Parameters ---------- index_column @@ -3323,8 +3370,8 @@ def rolling( Often of type Date/Datetime. This column must be sorted in ascending order. In case of a rolling group by on indices, dtype needs to be one of - {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if - performance matters use an Int64 column. + {UInt32, UInt64, Int32, Int64}. Note that the first three get temporarily + cast to Int64, so if performance matters use an Int64 column. period length of the window - must be non-negative offset @@ -4033,6 +4080,8 @@ def map_batches( Lambda/function to apply. return_dtype Dtype of the output Series. + If not set, the dtype will be inferred based on the first non-null value + that is returned by the function. is_elementwise If set to true this can run in the streaming engine, but may yield incorrect results in group-by. Ensure you know what you are doing! @@ -4157,7 +4206,8 @@ def map_elements( Lambda/function to map. return_dtype Dtype of the output Series. - If not set, the dtype will be `pl.Unknown`. + If not set, the dtype will be inferred based on the first non-null value + that is returned by the function. skip_nulls Don't map the function over values that contain nulls (this is faster). pass_name @@ -4352,7 +4402,7 @@ def get_lazy_promise(df: DataFrame) -> LazyFrame: if x.len() == 0: return get_lazy_promise(df).collect().to_series() - n_threads = threadpool_size() + n_threads = thread_pool_size() chunk_size = x.len() // n_threads remainder = x.len() % n_threads if chunk_size == 0: @@ -4380,7 +4430,8 @@ def get_lazy_promise(df: DataFrame) -> LazyFrame: wrap_threading, agg_list=True, return_dtype=return_dtype ) else: - ValueError(f"Strategy {strategy} is not supported.") + msg = f"strategy {strategy!r} is not supported" + raise ValueError(msg) def flatten(self) -> Self: """ @@ -5684,16 +5735,13 @@ def rolling_min( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -5899,16 +5947,13 @@ def rolling_max( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -6139,16 +6184,13 @@ def rolling_mean( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -6389,16 +6431,13 @@ def rolling_sum( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -6634,8 +6673,6 @@ def rolling_std( - ... - [t_n - window_size, t_n) - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. Parameters ---------- @@ -6876,16 +6913,13 @@ def rolling_var( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -7133,8 +7167,6 @@ def rolling_median( - ... - [t_n - window_size, t_n) - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. Parameters ---------- @@ -7287,16 +7319,13 @@ def rolling_quantile( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -7934,9 +7963,9 @@ def clip( └──────┴──────┘ """ if lower_bound is not None: - lower_bound = parse_as_expression(lower_bound, str_as_lit=True) + lower_bound = parse_as_expression(lower_bound) if upper_bound is not None: - upper_bound = parse_as_expression(upper_bound, str_as_lit=True) + upper_bound = parse_as_expression(upper_bound) return self._from_pyexpr(self._pyexpr.clip(lower_bound, upper_bound)) def lower_bound(self) -> Self: @@ -8526,7 +8555,7 @@ def ewm_mean( *, adjust: bool = True, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Self: r""" Exponentially-weighted moving average. @@ -8555,7 +8584,7 @@ def ewm_mean( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -8569,7 +8598,7 @@ def ewm_mean( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -8577,7 +8606,7 @@ def ewm_mean( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -8587,7 +8616,7 @@ def ewm_mean( Examples -------- >>> df = pl.DataFrame({"a": [1, 2, 3]}) - >>> df.select(pl.col("a").ewm_mean(com=1)) + >>> df.select(pl.col("a").ewm_mean(com=1, ignore_nulls=False)) shape: (3, 1) ┌──────────┐ │ a │ @@ -8599,6 +8628,16 @@ def ewm_mean( │ 2.428571 │ └──────────┘ """ + if ignore_nulls is None: + issue_deprecation_warning( + "The default value for `ignore_nulls` for `ewm` methods" + " will change from True to False in the next breaking release." + " Explicitly set `ignore_nulls=True` to keep the existing behavior" + " and silence this warning.", + version="0.20.11", + ) + ignore_nulls = True + alpha = _prepare_alpha(com, span, half_life, alpha) return self._from_pyexpr( self._pyexpr.ewm_mean(alpha, adjust, min_periods, ignore_nulls) @@ -8615,7 +8654,7 @@ def ewm_std( adjust: bool = True, bias: bool = False, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Self: r""" Exponentially-weighted moving standard deviation. @@ -8644,7 +8683,7 @@ def ewm_std( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -8661,7 +8700,7 @@ def ewm_std( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -8669,7 +8708,7 @@ def ewm_std( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -8679,7 +8718,7 @@ def ewm_std( Examples -------- >>> df = pl.DataFrame({"a": [1, 2, 3]}) - >>> df.select(pl.col("a").ewm_std(com=1)) + >>> df.select(pl.col("a").ewm_std(com=1, ignore_nulls=False)) shape: (3, 1) ┌──────────┐ │ a │ @@ -8691,6 +8730,16 @@ def ewm_std( │ 0.963624 │ └──────────┘ """ + if ignore_nulls is None: + issue_deprecation_warning( + "The default value for `ignore_nulls` for `ewm` methods" + " will change from True to False in the next breaking release." + " Explicitly set `ignore_nulls=True` to keep the existing behavior" + " and silence this warning.", + version="0.20.11", + ) + ignore_nulls = True + alpha = _prepare_alpha(com, span, half_life, alpha) return self._from_pyexpr( self._pyexpr.ewm_std(alpha, adjust, bias, min_periods, ignore_nulls) @@ -8707,7 +8756,7 @@ def ewm_var( adjust: bool = True, bias: bool = False, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Self: r""" Exponentially-weighted moving variance. @@ -8736,7 +8785,7 @@ def ewm_var( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -8753,7 +8802,7 @@ def ewm_var( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -8761,7 +8810,7 @@ def ewm_var( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -8771,7 +8820,7 @@ def ewm_var( Examples -------- >>> df = pl.DataFrame({"a": [1, 2, 3]}) - >>> df.select(pl.col("a").ewm_var(com=1)) + >>> df.select(pl.col("a").ewm_var(com=1, ignore_nulls=False)) shape: (3, 1) ┌──────────┐ │ a │ @@ -8783,19 +8832,29 @@ def ewm_var( │ 0.928571 │ └──────────┘ """ + if ignore_nulls is None: + issue_deprecation_warning( + "The default value for `ignore_nulls` for `ewm` methods" + " will change from True to False in the next breaking release." + " Explicitly set `ignore_nulls=True` to keep the existing behavior" + " and silence this warning.", + version="0.20.11", + ) + ignore_nulls = True + alpha = _prepare_alpha(com, span, half_life, alpha) return self._from_pyexpr( self._pyexpr.ewm_var(alpha, adjust, bias, min_periods, ignore_nulls) ) - def extend_constant(self, value: PythonLiteral | None, n: int) -> Self: + def extend_constant(self, value: IntoExpr, n: int | IntoExprColumn) -> Self: """ Extremely fast method for extending the Series with 'n' copies of a value. Parameters ---------- value - A constant literal value (not an expression) with which to extend the + A constant literal value or a unit expressioin with which to extend the expression result Series; can pass None to extend with nulls. n The number of additional values that will be added. @@ -8817,10 +8876,8 @@ def extend_constant(self, value: PythonLiteral | None, n: int) -> Self: │ 99 │ └────────┘ """ - if isinstance(value, Expr): - msg = f"`value` must be a supported literal; found {value!r}" - raise TypeError(msg) - + value = parse_as_expression(value, str_as_lit=True) + n = parse_as_expression(n) return self._from_pyexpr(self._pyexpr.extend_constant(value, n)) @deprecate_renamed_parameter("multithreaded", "parallel", version="0.19.0") @@ -9428,10 +9485,10 @@ def apply( - 'thread_local': run the python function on a single thread. - 'threading': run the python function on separate threads. Use with - care as this can slow performance. This might only speed up - your code if the amount of work per element is significant - and the python function releases the GIL (e.g. via calling - a c function) + care as this can slow performance. This might only speed up + your code if the amount of work per element is significant + and the python function releases the GIL (e.g. via calling + a c function) """ return self.map_elements( function, @@ -9816,6 +9873,29 @@ def map_dict( """ return self.replace(mapping, default=default, return_dtype=return_dtype) + @classmethod + def from_json(cls, value: str) -> Self: + """ + Read an expression from a JSON encoded string to construct an Expression. + + .. deprecated:: 0.20.11 + This method has been renamed to :meth:`deserialize`. + Note that the new method operates on file-like inputs rather than strings. + Enclose your input in `io.StringIO` to keep the same behavior. + + Parameters + ---------- + value + JSON encoded string value + """ + issue_deprecation_warning( + "`Expr.from_json` is deprecated. It has been renamed to `Expr.deserialize`." + " Note that the new method operates on file-like inputs rather than strings." + " Enclose your input in `io.StringIO` to keep the same behavior.", + version="0.20.11", + ) + return cls.deserialize(StringIO(value)) + @property def bin(self) -> ExprBinaryNameSpace: """ diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index eb9b9ca5e786..71139e65cb66 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -275,6 +275,80 @@ def mean(self) -> Expr: """ return wrap_expr(self._pyexpr.list_mean()) + def median(self) -> Expr: + """ + Compute the median value of the lists in the array. + + Examples + -------- + >>> df = pl.DataFrame({"values": [[-1, 0, 1], [1, 10]]}) + >>> df.with_columns(pl.col("values").list.median().alias("median")) + shape: (2, 2) + ┌────────────┬────────┐ + │ values ┆ median │ + │ --- ┆ --- │ + │ list[i64] ┆ f64 │ + ╞════════════╪════════╡ + │ [-1, 0, 1] ┆ 0.0 │ + │ [1, 10] ┆ 5.5 │ + └────────────┴────────┘ + """ + return wrap_expr(self._pyexpr.list_median()) + + def std(self, ddof: int = 1) -> Expr: + """ + Compute the std value of the lists in the array. + + Parameters + ---------- + ddof + “Delta Degrees of Freedom”: the divisor used in the calculation is N - ddof, + where N represents the number of elements. + By default ddof is 1. + + Examples + -------- + >>> df = pl.DataFrame({"values": [[-1, 0, 1], [1, 10]]}) + >>> df.with_columns(pl.col("values").list.std().alias("std")) + shape: (2, 2) + ┌────────────┬──────────┐ + │ values ┆ std │ + │ --- ┆ --- │ + │ list[i64] ┆ f64 │ + ╞════════════╪══════════╡ + │ [-1, 0, 1] ┆ 1.0 │ + │ [1, 10] ┆ 6.363961 │ + └────────────┴──────────┘ + """ + return wrap_expr(self._pyexpr.list_std(ddof)) + + def var(self, ddof: int = 1) -> Expr: + """ + Compute the var value of the lists in the array. + + Parameters + ---------- + ddof + “Delta Degrees of Freedom”: the divisor used in the calculation is N - ddof, + where N represents the number of elements. + By default ddof is 1. + + Examples + -------- + >>> df = pl.DataFrame({"values": [[-1, 0, 1], [1, 10]]}) + >>> df.with_columns(pl.col("values").list.var().alias("var")) + shape: (2, 2) + ┌────────────┬──────┐ + │ values ┆ var │ + │ --- ┆ --- │ + │ list[i64] ┆ f64 │ + ╞════════════╪══════╡ + │ [-1, 0, 1] ┆ 1.0 │ + │ [1, 10] ┆ 40.5 │ + └────────────┴──────┘ + """ + return wrap_expr(self._pyexpr.list_var(ddof)) + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Expr: """ Sort the lists in this column. @@ -368,6 +442,30 @@ def unique(self, *, maintain_order: bool = False) -> Expr: """ return wrap_expr(self._pyexpr.list_unique(maintain_order)) + def n_unique(self) -> Expr: + """ + Count the number of unique values in every sub-lists. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "a": [[1, 1, 2], [2, 3, 4]], + ... } + ... ) + >>> df.with_columns(n_unique=pl.col("a").list.n_unique()) + shape: (2, 2) + ┌───────────┬──────────┐ + │ a ┆ n_unique │ + │ --- ┆ --- │ + │ list[i64] ┆ u32 │ + ╞═══════════╪══════════╡ + │ [1, 1, 2] ┆ 2 │ + │ [2, 3, 4] ┆ 3 │ + └───────────┴──────────┘ + """ + return wrap_expr(self._pyexpr.list_n_unique()) + def concat(self, other: list[Expr | str] | Expr | str | Series | list[Any]) -> Expr: """ Concat the arrays in a Series dtype List in linear time. @@ -480,6 +578,50 @@ def gather( indices = parse_as_expression(indices) return wrap_expr(self._pyexpr.list_gather(indices, null_on_oob)) + def gather_every( + self, + n: int | IntoExprColumn, + offset: int | IntoExprColumn = 0, + ) -> Expr: + """ + Take every n-th value start from offset in sublists. + + Parameters + ---------- + n + Gather every n-th element. + offset + Starting index. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "a": [[1, 2, 3, 4, 5], [6, 7, 8], [9, 10, 11, 12]], + ... "n": [2, 1, 3], + ... "offset": [0, 1, 0], + ... } + ... ) + >>> df.with_columns( + ... gather_every=pl.col("a").list.gather_every( + ... n=pl.col("n"), offset=pl.col("offset") + ... ) + ... ) + shape: (3, 4) + ┌───────────────┬─────┬────────┬──────────────┐ + │ a ┆ n ┆ offset ┆ gather_every │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ list[i64] ┆ i64 ┆ i64 ┆ list[i64] │ + ╞═══════════════╪═════╪════════╪══════════════╡ + │ [1, 2, … 5] ┆ 2 ┆ 0 ┆ [1, 3, 5] │ + │ [6, 7, 8] ┆ 1 ┆ 1 ┆ [7, 8] │ + │ [9, 10, … 12] ┆ 3 ┆ 0 ┆ [9, 12] │ + └───────────────┴─────┴────────┴──────────────┘ + """ + n = parse_as_expression(n) + offset = parse_as_expression(offset) + return wrap_expr(self._pyexpr.list_gather_every(n, offset)) + def first(self) -> Expr: """ Get the first value of the sublists. diff --git a/py-polars/polars/expr/meta.py b/py-polars/polars/expr/meta.py index 4c5e0eb2eb0c..0c6c3fde46d7 100644 --- a/py-polars/polars/expr/meta.py +++ b/py-polars/polars/expr/meta.py @@ -6,7 +6,10 @@ from polars.exceptions import ComputeError from polars.utils._wrap import wrap_expr -from polars.utils.deprecation import deprecate_nonkeyword_arguments +from polars.utils.deprecation import ( + deprecate_nonkeyword_arguments, + deprecate_renamed_function, +) from polars.utils.various import normalize_filepath if TYPE_CHECKING: @@ -218,21 +221,48 @@ def _selector_and(self, other: Expr) -> Expr: return wrap_expr(self._pyexpr._meta_selector_and(other._pyexpr)) @overload - def write_json(self, file: None = ...) -> str: + def serialize(self, file: None = ...) -> str: ... @overload - def write_json(self, file: IOBase | str | Path) -> None: + def serialize(self, file: IOBase | str | Path) -> None: ... - def write_json(self, file: IOBase | str | Path | None = None) -> str | None: - """Write expression to json.""" + def serialize(self, file: IOBase | str | Path | None = None) -> str | None: + """ + Serialize this expression to a file or string in JSON format. + + Parameters + ---------- + file + File path to which the result should be written. If set to `None` + (default), the output is returned as a string instead. + + See Also + -------- + Expr.deserialize + + Examples + -------- + Serialize the expression into a JSON string. + + >>> expr = pl.col("foo").sum().over("bar") + >>> json = expr.meta.serialize() + >>> json + '{"Window":{"function":{"Agg":{"Sum":{"Column":"foo"}}},"partition_by":[{"Column":"bar"}],"options":{"Over":"GroupsToRows"}}}' + + The expression can later be deserialized back into an `Expr` object. + + >>> from io import StringIO + >>> pl.Expr.deserialize(StringIO(json)) # doctest: +ELLIPSIS + + """ if isinstance(file, (str, Path)): file = normalize_filepath(file) to_string_io = (file is not None) and isinstance(file, StringIO) if file is None or to_string_io: with BytesIO() as buf: - self._pyexpr.meta_write_json(buf) + self._pyexpr.serialize(buf) json_bytes = buf.getvalue() json_str = json_bytes.decode("utf8") @@ -241,9 +271,27 @@ def write_json(self, file: IOBase | str | Path | None = None) -> str | None: else: return json_str else: - self._pyexpr.meta_write_json(file) + self._pyexpr.serialize(file) return None + @overload + def write_json(self, file: None = ...) -> str: + ... + + @overload + def write_json(self, file: IOBase | str | Path) -> None: + ... + + @deprecate_renamed_function("Expr.meta.serialize", version="0.20.11") + def write_json(self, file: IOBase | str | Path | None = None) -> str | None: + """ + Write expression to json. + + .. deprecated:: 0.20.11 + This method has been renamed to :meth:`serialize`. + """ + return self.serialize(file) + @overload def tree_format(self, *, return_as_string: Literal[False]) -> None: ... diff --git a/py-polars/polars/expr/name.py b/py-polars/polars/expr/name.py index 369bb50b5bb4..482b30ef60ff 100644 --- a/py-polars/polars/expr/name.py +++ b/py-polars/polars/expr/name.py @@ -21,8 +21,11 @@ def keep(self) -> Expr: Notes ----- + This will undo any previous renaming operations on the expression. + Due to implementation constraints, this method can only be called as the last - expression in a chain. + expression in a chain. Only one name operation per expression will work. + Consider using `.name.map` for advanced renaming. See Also -------- @@ -69,6 +72,14 @@ def map(self, function: Callable[[str], str]) -> Expr: """ Rename the output of an expression by mapping a function over the root name. + Notes + ----- + This will undo any previous renaming operations on the expression. + + Due to implementation constraints, this method can only be called as the last + expression in a chain. Only one name operation per expression will work. + + Parameters ---------- function @@ -115,12 +126,14 @@ def prefix(self, prefix: str) -> Expr: prefix Prefix to add to the root column name. + Notes ----- This will undo any previous renaming operations on the expression. Due to implementation constraints, this method can only be called as the last - expression in a chain. + expression in a chain. Only one name operation per expression will work. + Consider using `.name.map` for advanced renaming. See Also -------- @@ -162,7 +175,8 @@ def suffix(self, suffix: str) -> Expr: This will undo any previous renaming operations on the expression. Due to implementation constraints, this method can only be called as the last - expression in a chain. + expression in a chain. Only one name operation per expression will work. + Consider using `.name.map` for advanced renaming. See Also -------- @@ -196,8 +210,11 @@ def to_lowercase(self) -> Expr: Notes ----- + This will undo any previous renaming operations on the expression. + Due to implementation constraints, this method can only be called as the last - expression in a chain. + expression in a chain. Only one name operation per expression will work. + Consider using `.name.map` for advanced renaming. See Also -------- @@ -233,8 +250,11 @@ def to_uppercase(self) -> Expr: Notes ----- + This will undo any previous renaming operations on the expression. + Due to implementation constraints, this method can only be called as the last - expression in a chain. + expression in a chain. Only one name operation per expression will work. + Consider using `.name.map` for advanced renaming. See Also -------- @@ -263,3 +283,66 @@ def to_uppercase(self) -> Expr: └──────┴──────┴──────┴──────┘ """ return self._from_pyexpr(self._pyexpr.name_to_uppercase()) + + def map_fields(self, function: Callable[[str], str]) -> Expr: + """ + Rename fields of a struct by mapping a function over the field name. + + Notes + ----- + This only take effects for struct. + + Parameters + ---------- + function + Function that maps a field name to a new name. + + Examples + -------- + >>> df = pl.DataFrame({"x": {"a": 1, "b": 2}}) + >>> df.select(pl.col("x").name.map_fields(lambda x: x.upper())).schema + OrderedDict({'x': Struct({'A': Int64, 'B': Int64})}) + """ + return self._from_pyexpr(self._pyexpr.name_map_fields(function)) + + def prefix_fields(self, prefix: str) -> Expr: + """ + Add a prefix to all fields name of a struct. + + Notes + ----- + This only take effects for struct. + + Parameters + ---------- + prefix + Prefix to add to the filed name + + Examples + -------- + >>> df = pl.DataFrame({"x": {"a": 1, "b": 2}}) + >>> df.select(pl.col("x").name.prefix_fields("prefix_")).schema + OrderedDict({'x': Struct({'prefix_a': Int64, 'prefix_b': Int64})}) + """ + return self._from_pyexpr(self._pyexpr.name_prefix_fields(prefix)) + + def suffix_fields(self, suffix: str) -> Expr: + """ + Add a suffix to all fields name of a struct. + + Notes + ----- + This only take effects for struct. + + Parameters + ---------- + suffix + Suffix to add to the filed name + + Examples + -------- + >>> df = pl.DataFrame({"x": {"a": 1, "b": 2}}) + >>> df.select(pl.col("x").name.suffix_fields("_suffix")).schema + OrderedDict({'x': Struct({'a_suffix': Int64, 'b_suffix': Int64})}) + """ + return self._from_pyexpr(self._pyexpr.name_suffix_fields(suffix)) diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 2b2f7b16baed..aa2fcf6d49e3 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -6,6 +6,7 @@ import polars._reexport as pl from polars import functions as F from polars.datatypes import Date, Datetime, Int32, Time, py_type_to_dtype +from polars.datatypes.constants import N_INFER_DEFAULT from polars.exceptions import ChronoFormatWarning from polars.utils._parse_expr_input import parse_as_expression from polars.utils._wrap import wrap_expr @@ -428,6 +429,12 @@ def len_chars(self) -> Expr: equivalent output with much better performance: :func:`len_bytes` runs in _O(1)_, while :func:`len_chars` runs in (_O(n)_). + A character is defined as a `Unicode scalar value`_. A single character is + represented by a single byte when working with ASCII text, and a maximum of + 4 bytes otherwise. + + .. _Unicode scalar value: https://www.unicode.org/glossary/#unicode_scalar_value + Examples -------- >>> df = pl.DataFrame({"a": ["Café", "345", "東京", None]}) @@ -1250,7 +1257,9 @@ def starts_with(self, prefix: str | Expr) -> Expr: return wrap_expr(self._pyexpr.str_starts_with(prefix)) def json_decode( - self, dtype: PolarsDataType | None = None, infer_schema_length: int | None = 100 + self, + dtype: PolarsDataType | None = None, + infer_schema_length: int | None = N_INFER_DEFAULT, ) -> Expr: """ Parse string values as JSON. @@ -1263,8 +1272,8 @@ def json_decode( The dtype to cast the extracted value to. If None, the dtype will be inferred from the JSON value. infer_schema_length - How many rows to parse to determine the schema. - If `None` all rows are used. + The maximum number of rows to scan for schema inference. + If set to `None`, the full data may be scanned *(this is slow)*. See Also -------- @@ -1337,8 +1346,8 @@ def json_path_match(self, json_path: str) -> Expr: return wrap_expr(self._pyexpr.str_json_path_match(json_path)) def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: - """ - Decode a value using the provided encoding. + r""" + Decode values using the provided encoding. Parameters ---------- @@ -1347,6 +1356,26 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: strict Raise an error if the underlying value cannot be decoded, otherwise mask out with a null value. + + Returns + ------- + Expr + Expression of data type :class:`Binary`. + + Examples + -------- + >>> df = pl.DataFrame({"color": ["000000", "ffff00", "0000ff"]}) + >>> df.with_columns(pl.col("color").str.decode("hex").alias("decoded")) + shape: (3, 2) + ┌────────┬─────────────────┐ + │ color ┆ decoded │ + │ --- ┆ --- │ + │ str ┆ binary │ + ╞════════╪═════════════════╡ + │ 000000 ┆ b"\x00\x00\x00" │ + │ ffff00 ┆ b"\xff\xff\x00" │ + │ 0000ff ┆ b"\x00\x00\xff" │ + └────────┴─────────────────┘ """ if encoding == "hex": return wrap_expr(self._pyexpr.str_hex_decode(strict)) @@ -1358,7 +1387,7 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: def encode(self, encoding: TransferEncoding) -> Expr: """ - Encode a value using the provided encoding. + Encode values using the provided encoding. Parameters ---------- @@ -1917,14 +1946,58 @@ def replace( value String that will replace the matched substring. literal - Treat pattern as a literal string. + Treat `pattern` as a literal string. n Number of matches to replace. + See Also + -------- + replace_all + Notes ----- + The dollar sign (`$`) is a special character related to capture groups. + To refer to a literal dollar sign, use `$$` instead or set `literal` to `True`. + To modify regular expression behaviour (such as case-sensitivity) with flags, - use the inline `(?iLmsuxU)` syntax. For example: + use the inline `(?iLmsuxU)` syntax. See the regex crate's section on + `grouping and flags `_ + for additional information about the use of inline expression modifiers. + + Examples + -------- + >>> df = pl.DataFrame({"id": [1, 2], "text": ["123abc", "abc456"]}) + >>> df.with_columns(pl.col("text").str.replace(r"abc\b", "ABC")) + shape: (2, 2) + ┌─────┬────────┐ + │ id ┆ text │ + │ --- ┆ --- │ + │ i64 ┆ str │ + ╞═════╪════════╡ + │ 1 ┆ 123ABC │ + │ 2 ┆ abc456 │ + └─────┴────────┘ + + Capture groups are supported. Use `${1}` in the `value` string to refer to the + first capture group in the `pattern`, `${2}` to refer to the second capture + group, and so on. You can also use named capture groups. + + >>> df = pl.DataFrame({"word": ["hat", "hut"]}) + >>> df.with_columns( + ... positional=pl.col.word.str.replace("h(.)t", "b${1}d"), + ... named=pl.col.word.str.replace("h(?.)t", "b${vowel}d"), + ... ) + shape: (2, 3) + ┌──────┬────────────┬───────┐ + │ word ┆ positional ┆ named │ + │ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ str │ + ╞══════╪════════════╪═══════╡ + │ hat ┆ bad ┆ bad │ + │ hut ┆ bud ┆ bud │ + └──────┴────────────┴───────┘ + + Apply case-insensitive string replacement using the `(?i)` flag. >>> df = pl.DataFrame( ... { @@ -1934,7 +2007,6 @@ def replace( ... } ... ) >>> df.with_columns( - ... # apply case-insensitive string replacement ... pl.col("weather").str.replace(r"(?i)foggy|rainy|cloudy|snowy", "Sunny") ... ) shape: (4, 3) @@ -1948,30 +2020,6 @@ def replace( │ Philadelphia ┆ Autumn ┆ Sunny │ │ Philadelphia ┆ Winter ┆ Sunny │ └──────────────┴────────┴─────────┘ - - See the regex crate's section on `grouping and flags - `_ for - additional information about the use of inline expression modifiers. - - See Also - -------- - replace_all : Replace all matching regex/literal substrings. - - Examples - -------- - >>> df = pl.DataFrame({"id": [1, 2], "text": ["123abc", "abc456"]}) - >>> df.with_columns( - ... pl.col("text").str.replace(r"abc\b", "ABC") - ... ) # doctest: +IGNORE_RESULT - shape: (2, 2) - ┌─────┬────────┐ - │ id ┆ text │ - │ --- ┆ --- │ - │ i64 ┆ str │ - ╞═════╪════════╡ - │ 1 ┆ 123ABC │ - │ 2 ┆ abc456 │ - └─────┴────────┘ """ pattern = parse_as_expression(pattern, str_as_lit=True) value = parse_as_expression(value, str_as_lit=True) @@ -1980,7 +2028,7 @@ def replace( def replace_all( self, pattern: str | Expr, value: str | Expr, *, literal: bool = False ) -> Expr: - """ + r""" Replace all matching regex/literal substrings with a new string value. Parameters @@ -1989,13 +2037,23 @@ def replace_all( A valid regular expression pattern, compatible with the `regex crate `_. value - Replacement string. + String that will replace the matched substring. literal - Treat pattern as a literal string. + Treat `pattern` as a literal string. See Also -------- - replace : Replace first matching regex/literal substring. + replace + + Notes + ----- + The dollar sign (`$`) is a special character related to capture groups. + To refer to a literal dollar sign, use `$$` instead or set `literal` to `True`. + + To modify regular expression behaviour (such as case-sensitivity) with flags, + use the inline `(?iLmsuxU)` syntax. See the regex crate's section on + `grouping and flags `_ + for additional information about the use of inline expression modifiers. Examples -------- @@ -2010,6 +2068,52 @@ def replace_all( │ 1 ┆ -bc-bc │ │ 2 ┆ 123-123 │ └─────┴─────────┘ + + Capture groups are supported. Use `${1}` in the `value` string to refer to the + first capture group in the `pattern`, `${2}` to refer to the second capture + group, and so on. You can also use named capture groups. + + >>> df = pl.DataFrame({"word": ["hat", "hut"]}) + >>> df.with_columns( + ... positional=pl.col.word.str.replace_all("h(.)t", "b${1}d"), + ... named=pl.col.word.str.replace_all("h(?.)t", "b${vowel}d"), + ... ) + shape: (2, 3) + ┌──────┬────────────┬───────┐ + │ word ┆ positional ┆ named │ + │ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ str │ + ╞══════╪════════════╪═══════╡ + │ hat ┆ bad ┆ bad │ + │ hut ┆ bud ┆ bud │ + └──────┴────────────┴───────┘ + + Apply case-insensitive string replacement using the `(?i)` flag. + + >>> df = pl.DataFrame( + ... { + ... "city": "Philadelphia", + ... "season": ["Spring", "Summer", "Autumn", "Winter"], + ... "weather": ["Rainy", "Sunny", "Cloudy", "Snowy"], + ... } + ... ) + >>> df.with_columns( + ... # apply case-insensitive string replacement + ... pl.col("weather").str.replace_all( + ... r"(?i)foggy|rainy|cloudy|snowy", "Sunny" + ... ) + ... ) + shape: (4, 3) + ┌──────────────┬────────┬─────────┐ + │ city ┆ season ┆ weather │ + │ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ str │ + ╞══════════════╪════════╪═════════╡ + │ Philadelphia ┆ Spring ┆ Sunny │ + │ Philadelphia ┆ Summer ┆ Sunny │ + │ Philadelphia ┆ Autumn ┆ Sunny │ + │ Philadelphia ┆ Winter ┆ Sunny │ + └──────────────┴────────┴─────────┘ """ pattern = parse_as_expression(pattern, str_as_lit=True) value = parse_as_expression(value, str_as_lit=True) @@ -2040,7 +2144,7 @@ def slice( self, offset: int | IntoExprColumn, length: int | IntoExprColumn | None = None ) -> Expr: """ - Create subslices of the string values of a String Series. + Extract a substring from each string value. Parameters ---------- @@ -2055,40 +2159,45 @@ def slice( Expr Expression of data type :class:`String`. + Notes + ----- + Both the `offset` and `length` inputs are defined in terms of the number + of characters in the (UTF8) string. A character is defined as a + `Unicode scalar value`_. A single character is represented by a single byte + when working with ASCII text, and a maximum of 4 bytes otherwise. + + .. _Unicode scalar value: https://www.unicode.org/glossary/#unicode_scalar_value + Examples -------- >>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]}) - >>> df.with_columns( - ... pl.col("s").str.slice(-3).alias("s_sliced"), - ... ) + >>> df.with_columns(pl.col("s").str.slice(-3).alias("slice")) shape: (4, 2) - ┌─────────────┬──────────┐ - │ s ┆ s_sliced │ - │ --- ┆ --- │ - │ str ┆ str │ - ╞═════════════╪══════════╡ - │ pear ┆ ear │ - │ null ┆ null │ - │ papaya ┆ aya │ - │ dragonfruit ┆ uit │ - └─────────────┴──────────┘ + ┌─────────────┬───────┐ + │ s ┆ slice │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞═════════════╪═══════╡ + │ pear ┆ ear │ + │ null ┆ null │ + │ papaya ┆ aya │ + │ dragonfruit ┆ uit │ + └─────────────┴───────┘ Using the optional `length` parameter - >>> df.with_columns( - ... pl.col("s").str.slice(4, length=3).alias("s_sliced"), - ... ) + >>> df.with_columns(pl.col("s").str.slice(4, length=3).alias("slice")) shape: (4, 2) - ┌─────────────┬──────────┐ - │ s ┆ s_sliced │ - │ --- ┆ --- │ - │ str ┆ str │ - ╞═════════════╪══════════╡ - │ pear ┆ │ - │ null ┆ null │ - │ papaya ┆ ya │ - │ dragonfruit ┆ onf │ - └─────────────┴──────────┘ + ┌─────────────┬───────┐ + │ s ┆ slice │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞═════════════╪═══════╡ + │ pear ┆ │ + │ null ┆ null │ + │ papaya ┆ ya │ + │ dragonfruit ┆ onf │ + └─────────────┴───────┘ """ offset = parse_as_expression(offset) length = parse_as_expression(length) @@ -2125,7 +2234,7 @@ def explode(self) -> Expr: def to_integer(self, *, base: int = 10, strict: bool = True) -> Expr: """ - Convert an String column into an Int64 column with base radix. + Convert a String column into an Int64 column with base radix. Parameters ---------- @@ -2327,7 +2436,9 @@ def rjust(self, length: int, fill_char: str = " ") -> Expr: @deprecate_renamed_function("json_decode", version="0.19.12") def json_extract( - self, dtype: PolarsDataType | None = None, infer_schema_length: int | None = 100 + self, + dtype: PolarsDataType | None = None, + infer_schema_length: int | None = N_INFER_DEFAULT, ) -> Expr: """ Parse string values as JSON. @@ -2341,8 +2452,8 @@ def json_extract( The dtype to cast the extracted value to. If None, the dtype will be inferred from the JSON value. infer_schema_length - How many rows to parse to determine the schema. - If `None` all rows are used. + The maximum number of rows to scan for schema inference. + If set to `None`, the full data may be scanned *(this is slow)*. """ return self.json_decode(dtype, infer_schema_length) diff --git a/py-polars/polars/expr/whenthen.py b/py-polars/polars/expr/whenthen.py index f357289fec46..ced5df6d68fe 100644 --- a/py-polars/polars/expr/whenthen.py +++ b/py-polars/polars/expr/whenthen.py @@ -24,6 +24,8 @@ class When: In this state, `then` must be called to continue to finish the expression. """ + __slots__ = ("_when",) + def __init__(self, when: Any): self._when = when @@ -49,6 +51,8 @@ class Then(Expr): Represents the state of the expression after `pl.when(...).then(...)` is called. """ + __slots__ = ("_then",) + def __init__(self, then: Any): self._then = then @@ -106,6 +110,8 @@ class ChainedWhen(Expr): In this state, `then` must be called to continue to finish the expression. """ + __slots__ = ("_chained_when",) + def __init__(self, chained_when: Any): self._chained_when = chained_when @@ -131,6 +137,8 @@ class ChainedThen(Expr): Represents the state of the expression after an additional `then` is called. """ + __slots__ = ("_chained_then",) + def __init__(self, chained_then: Any): self._chained_then = chained_then diff --git a/py-polars/polars/functions/__init__.py b/py-polars/polars/functions/__init__.py index 22403b09ca77..935cdfe159a6 100644 --- a/py-polars/polars/functions/__init__.py +++ b/py-polars/polars/functions/__init__.py @@ -9,6 +9,7 @@ cumsum_horizontal, max, max_horizontal, + mean_horizontal, min, min_horizontal, sum, @@ -156,6 +157,7 @@ "map_batches", "map_groups", "mean", + "mean_horizontal", "median", "n_unique", "quantile", diff --git a/py-polars/polars/functions/aggregation/__init__.py b/py-polars/polars/functions/aggregation/__init__.py index 9f99611f9121..1d50e9770d83 100644 --- a/py-polars/polars/functions/aggregation/__init__.py +++ b/py-polars/polars/functions/aggregation/__init__.py @@ -4,6 +4,7 @@ cum_sum_horizontal, cumsum_horizontal, max_horizontal, + mean_horizontal, min_horizontal, sum_horizontal, ) @@ -30,6 +31,7 @@ "cum_sum_horizontal", "cumsum_horizontal", "max_horizontal", + "mean_horizontal", "min_horizontal", "sum_horizontal", ] diff --git a/py-polars/polars/functions/aggregation/horizontal.py b/py-polars/polars/functions/aggregation/horizontal.py index 023099a6bb28..6d06aab8162c 100644 --- a/py-polars/polars/functions/aggregation/horizontal.py +++ b/py-polars/polars/functions/aggregation/horizontal.py @@ -27,26 +27,35 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: Column(s) to use in the aggregation. Accepts expression input. Strings are parsed as column names, other non-expression inputs are parsed as literals. + Notes + ----- + `Kleene logic`_ is used to deal with nulls: if the column contains any null values + and no `False` values, the output is null. + + .. _Kleene logic: https://en.wikipedia.org/wiki/Three-valued_logic + Examples -------- >>> df = pl.DataFrame( ... { - ... "a": [False, False, True, True], - ... "b": [False, True, None, True], - ... "c": ["w", "x", "y", "z"], + ... "a": [False, False, True, True, False, None], + ... "b": [False, True, True, None, None, None], + ... "c": ["u", "v", "w", "x", "y", "z"], ... } ... ) >>> df.with_columns(all=pl.all_horizontal("a", "b")) - shape: (4, 4) + shape: (6, 4) ┌───────┬───────┬─────┬───────┐ │ a ┆ b ┆ c ┆ all │ │ --- ┆ --- ┆ --- ┆ --- │ │ bool ┆ bool ┆ str ┆ bool │ ╞═══════╪═══════╪═════╪═══════╡ - │ false ┆ false ┆ w ┆ false │ - │ false ┆ true ┆ x ┆ false │ - │ true ┆ null ┆ y ┆ null │ - │ true ┆ true ┆ z ┆ true │ + │ false ┆ false ┆ u ┆ false │ + │ false ┆ true ┆ v ┆ false │ + │ true ┆ true ┆ w ┆ true │ + │ true ┆ null ┆ x ┆ null │ + │ false ┆ null ┆ y ┆ false │ + │ null ┆ null ┆ z ┆ null │ └───────┴───────┴─────┴───────┘ """ pyexprs = parse_as_list_of_expressions(*exprs) @@ -63,25 +72,34 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: Column(s) to use in the aggregation. Accepts expression input. Strings are parsed as column names, other non-expression inputs are parsed as literals. + Notes + ----- + `Kleene logic`_ is used to deal with nulls: if the column contains any null values + and no `True` values, the output is null. + + .. _Kleene logic: https://en.wikipedia.org/wiki/Three-valued_logic + Examples -------- >>> df = pl.DataFrame( ... { - ... "a": [False, False, True, None], - ... "b": [False, True, None, None], - ... "c": ["w", "x", "y", "z"], + ... "a": [False, False, True, True, False, None], + ... "b": [False, True, True, None, None, None], + ... "c": ["u", "v", "w", "x", "y", "z"], ... } ... ) >>> df.with_columns(any=pl.any_horizontal("a", "b")) - shape: (4, 4) + shape: (6, 4) ┌───────┬───────┬─────┬───────┐ │ a ┆ b ┆ c ┆ any │ │ --- ┆ --- ┆ --- ┆ --- │ │ bool ┆ bool ┆ str ┆ bool │ ╞═══════╪═══════╪═════╪═══════╡ - │ false ┆ false ┆ w ┆ false │ - │ false ┆ true ┆ x ┆ true │ - │ true ┆ null ┆ y ┆ true │ + │ false ┆ false ┆ u ┆ false │ + │ false ┆ true ┆ v ┆ true │ + │ true ┆ true ┆ w ┆ true │ + │ true ┆ null ┆ x ┆ true │ + │ false ┆ null ┆ y ┆ null │ │ null ┆ null ┆ z ┆ null │ └───────┴───────┴─────┴───────┘ """ @@ -194,6 +212,41 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return wrap_expr(plr.sum_horizontal(pyexprs)) +def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: + """ + Compute the mean of all values horizontally across columns. + + Parameters + ---------- + *exprs + Column(s) to use in the aggregation. Accepts expression input. Strings are + parsed as column names, other non-expression inputs are parsed as literals. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "a": [1, 8, 3], + ... "b": [4, 5, None], + ... "c": ["x", "y", "z"], + ... } + ... ) + >>> df.with_columns(mean=pl.mean_horizontal("a", "b")) + shape: (3, 4) + ┌─────┬──────┬─────┬──────┐ + │ a ┆ b ┆ c ┆ mean │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ str ┆ f64 │ + ╞═════╪══════╪═════╪══════╡ + │ 1 ┆ 4 ┆ x ┆ 2.5 │ + │ 8 ┆ 5 ┆ y ┆ 6.5 │ + │ 3 ┆ null ┆ z ┆ 3.0 │ + └─────┴──────┴─────┴──────┘ + """ + pyexprs = parse_as_list_of_expressions(*exprs) + return wrap_expr(plr.mean_horizontal(pyexprs)) + + def cum_sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: """ Cumulatively sum all values horizontally across columns. diff --git a/py-polars/polars/functions/aggregation/vertical.py b/py-polars/polars/functions/aggregation/vertical.py index 2137b83e0371..16828027f3dd 100644 --- a/py-polars/polars/functions/aggregation/vertical.py +++ b/py-polars/polars/functions/aggregation/vertical.py @@ -24,8 +24,8 @@ def all(*names: str, ignore_nulls: bool = True) -> Expr: Ignore null values (default). If set to `False`, `Kleene logic`_ is used to deal with nulls: - if the column contains any null values and no `True` values, - the output is `None`. + if the column contains any null values and no `False` values, + the output is null. .. _Kleene logic: https://en.wikipedia.org/wiki/Three-valued_logic @@ -90,7 +90,7 @@ def any(*names: str, ignore_nulls: bool = True) -> Expr | bool | None: If set to `False`, `Kleene logic`_ is used to deal with nulls: if the column contains any null values and no `True` values, - the output is `None`. + the output is null. .. _Kleene logic: https://en.wikipedia.org/wiki/Three-valued_logic diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 6aa675f444cc..bbd1e1d14b38 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -97,8 +97,8 @@ def count(*columns: str) -> Expr: This function is syntactic sugar for `col(columns).count()`. Calling this function without any arguments returns the number of rows in the - context. **This way of using the function is deprecated. Please use :func:`len` - instead.** + context. **This way of using the function is deprecated.** Please use :func:`len` + instead. Parameters ---------- @@ -146,7 +146,7 @@ def count(*columns: str) -> Expr: └─────┴─────┘ Return the number of rows in a context. **This way of using the function is - deprecated. Please use :func:`len` instead.** + deprecated.** Please use :func:`len` instead. >>> df.select(pl.count()) # doctest: +SKIP shape: (1, 1) @@ -346,6 +346,10 @@ def mean(*columns: str) -> Expr: *columns One or more column names. + See Also + -------- + mean_horizontal + Examples -------- >>> df = pl.DataFrame( diff --git a/py-polars/polars/functions/lit.py b/py-polars/polars/functions/lit.py index 83075f5f317e..fc03ee0fa005 100644 --- a/py-polars/polars/functions/lit.py +++ b/py-polars/polars/functions/lit.py @@ -10,6 +10,7 @@ from polars.dependencies import numpy as np from polars.utils._wrap import wrap_expr from polars.utils.convert import ( + _date_to_pl_date, _datetime_to_pl_timestamp, _time_to_pl_time, _timedelta_to_pl_timedelta, @@ -35,7 +36,8 @@ def lit( value Value that should be used as a `literal`. dtype - Optionally define a dtype. + The data type of the resulting expression. + If set to `None` (default), the data type is inferred from the `value` input. allow_object If type is unknown use an 'object' type. By default, we will raise a `ValueException` @@ -43,7 +45,7 @@ def lit( Notes ----- - Expected datatypes + Expected datatypes: - `pl.lit([])` -> empty Series Float32 - `pl.lit([1, 2, 3])` -> Series Int64 @@ -75,41 +77,44 @@ def lit( time_unit: TimeUnit if isinstance(value, datetime): - time_unit = "us" if dtype is None else getattr(dtype, "time_unit", "us") - time_zone = ( - value.tzinfo - if getattr(dtype, "time_zone", None) is None - else getattr(dtype, "time_zone", None) - ) - if ( - value.tzinfo is not None - and getattr(dtype, "time_zone", None) is not None - and dtype.time_zone != str(value.tzinfo) # type: ignore[union-attr] - ): - msg = f"time zone of dtype ({dtype.time_zone!r}) differs from time zone of value ({value.tzinfo!r})" # type: ignore[union-attr] - raise TypeError(msg) - e = lit( - _datetime_to_pl_timestamp(value.replace(tzinfo=timezone.utc), time_unit) - ).cast(Datetime(time_unit)) + if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None: + time_unit = tu # type: ignore[assignment] + else: + time_unit = "us" + + time_zone: str | None = getattr(dtype, "time_zone", None) + if (tzinfo := value.tzinfo) is not None: + tzinfo_str = str(tzinfo) + if time_zone is not None and time_zone != tzinfo_str: + msg = f"time zone of dtype ({time_zone!r}) differs from time zone of value ({tzinfo!r})" + raise TypeError(msg) + time_zone = tzinfo_str + + dt_utc = value.replace(tzinfo=timezone.utc) + dt_int = _datetime_to_pl_timestamp(dt_utc, time_unit) + expr = lit(dt_int).cast(Datetime(time_unit)) if time_zone is not None: - return e.dt.replace_time_zone( - str(time_zone), ambiguous="earliest" if value.fold == 0 else "latest" + expr = expr.dt.replace_time_zone( + time_zone, ambiguous="earliest" if value.fold == 0 else "latest" ) - else: - return e + return expr elif isinstance(value, timedelta): - if dtype is None or (time_unit := getattr(dtype, "time_unit", "us")) is None: + if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None: + time_unit = tu # type: ignore[assignment] + else: time_unit = "us" - return lit(_timedelta_to_pl_timedelta(value, time_unit)).cast( - Duration(time_unit) - ) + + td_int = _timedelta_to_pl_timedelta(value, time_unit) + return lit(td_int).cast(Duration(time_unit)) elif isinstance(value, time): - return lit(_time_to_pl_time(value)).cast(Time) + time_int = _time_to_pl_time(value) + return lit(time_int).cast(Time) elif isinstance(value, date): - return lit(datetime(value.year, value.month, value.day)).cast(Date) + date_int = _date_to_pl_date(value) + return lit(date_int).cast(Date) elif isinstance(value, pl.Series): value = value._s diff --git a/py-polars/polars/io/_utils.py b/py-polars/polars/io/_utils.py index ac2123ecc506..2efbb3616d48 100644 --- a/py-polars/polars/io/_utils.py +++ b/py-polars/polars/io/_utils.py @@ -5,7 +5,8 @@ from contextlib import contextmanager from io import BytesIO, StringIO from pathlib import Path -from typing import IO, Any, ContextManager, Iterator, overload +from tempfile import NamedTemporaryFile +from typing import IO, Any, ContextManager, Iterator, cast, overload from polars.dependencies import _FSSPEC_AVAILABLE, fsspec from polars.exceptions import NoDataError @@ -31,7 +32,7 @@ def _is_local_file(file: str) -> bool: @overload def _prepare_file_arg( - file: str | list[str] | Path | IO[bytes] | bytes, + file: str | Path | list[str] | IO[bytes] | bytes, encoding: str | None = ..., *, use_pyarrow: bool = ..., @@ -55,7 +56,7 @@ def _prepare_file_arg( @overload def _prepare_file_arg( - file: str | list[str] | Path | IO[str] | IO[bytes] | bytes, + file: str | Path | list[str] | IO[str] | IO[bytes] | bytes, encoding: str | None = ..., *, use_pyarrow: bool = ..., @@ -66,7 +67,7 @@ def _prepare_file_arg( def _prepare_file_arg( - file: str | list[str] | Path | IO[str] | IO[bytes] | bytes, + file: str | Path | list[str] | IO[str] | IO[bytes] | bytes, encoding: str | None = None, *, use_pyarrow: bool = False, @@ -234,3 +235,44 @@ def _process_file_url(path: str, encoding: str | None = None) -> BytesIO: return BytesIO(f.read()) else: return BytesIO(f.read().decode(encoding).encode("utf8")) + + +@contextmanager +def PortableTemporaryFile( + mode: str = "w+b", + *, + buffering: int = -1, + encoding: str | None = None, + newline: str | None = None, + suffix: str | None = None, + prefix: str | None = None, + dir: str | Path | None = None, + delete: bool = True, + errors: str | None = None, +) -> Iterator[Any]: + """ + Slightly more resilient version of the standard `NamedTemporaryFile`. + + Plays better with Windows when using the 'delete' option. + """ + params = cast( + Any, + { + "mode": mode, + "buffering": buffering, + "encoding": encoding, + "newline": newline, + "suffix": suffix, + "prefix": prefix, + "dir": dir, + "delete": False, + "errors": errors, + }, + ) + tmp = NamedTemporaryFile(**params) + try: + yield tmp + finally: + tmp.close() + if delete: + Path(tmp.name).unlink(missing_ok=True) diff --git a/py-polars/polars/io/avro.py b/py-polars/polars/io/avro.py index e93667ee00ae..a25be704c4f2 100644 --- a/py-polars/polars/io/avro.py +++ b/py-polars/polars/io/avro.py @@ -1,18 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, BinaryIO +from typing import IO, TYPE_CHECKING import polars._reexport as pl if TYPE_CHECKING: - from io import BytesIO from pathlib import Path from polars import DataFrame def read_avro( - source: str | Path | BytesIO | BinaryIO, + source: str | Path | IO[bytes] | bytes, *, columns: list[int] | list[str] | None = None, n_rows: int | None = None, diff --git a/py-polars/polars/io/csv/batched_reader.py b/py-polars/polars/io/csv/batched_reader.py index 84ad7ba57b09..a4f416e8a268 100644 --- a/py-polars/polars/io/csv/batched_reader.py +++ b/py-polars/polars/io/csv/batched_reader.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -from pathlib import Path from typing import TYPE_CHECKING, Sequence from polars.datatypes import N_INFER_DEFAULT, py_type_to_dtype @@ -18,6 +17,8 @@ from polars.polars import PyBatchedCsv if TYPE_CHECKING: + from pathlib import Path + from polars import DataFrame from polars.type_aliases import CsvEncoding, PolarsDataType, SchemaDict @@ -56,9 +57,7 @@ def __init__( raise_if_empty: bool = True, truncate_ragged_lines: bool = False, ): - path: str | None - if isinstance(source, (str, Path)): - path = normalize_filepath(source) + path = normalize_filepath(source) dtype_list: Sequence[tuple[str, PolarsDataType]] | None = None dtype_slice: Sequence[PolarsDataType] | None = None diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py index cd91df36dcd9..e3ad2cd5c80f 100644 --- a/py-polars/polars/io/csv/functions.py +++ b/py-polars/polars/io/csv/functions.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Mapping, Sequence, TextIO +from typing import IO, TYPE_CHECKING, Any, Callable, Mapping, Sequence import polars._reexport as pl from polars.datatypes import N_INFER_DEFAULT, String @@ -12,8 +12,6 @@ from polars.utils.various import handle_projection_columns, normalize_filepath if TYPE_CHECKING: - from io import BytesIO - from polars import DataFrame, LazyFrame from polars.type_aliases import CsvEncoding, PolarsDataType, SchemaDict @@ -24,7 +22,7 @@ old_name="comment_char", new_name="comment_prefix", version="0.19.14" ) def read_csv( - source: str | TextIO | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[str] | IO[bytes] | bytes, *, has_header: bool = True, columns: Sequence[int] | Sequence[str] | None = None, @@ -101,6 +99,7 @@ def read_csv( - `List[str]`: All values equal to any string in this list will be null. - `Dict[str, str]`: A dictionary that maps column name to a null value string. + missing_utf8_is_empty_string By default a missing value is considered to be null; if you would prefer missing utf8 values to be treated as the empty string you can set this param True. @@ -120,12 +119,9 @@ def read_csv( Number of threads to use in csv parsing. Defaults to the number of physical cpu's of your system. infer_schema_length - Maximum number of lines to read to infer schema. - If schema is inferred wrongly (e.g. as `pl.Int64` instead of `pl.Float64`), - try to increase the number of lines used to infer the schema or override - inferred dtype for those columns with `dtypes`. - If set to 0, all columns will be read as `pl.String`. - If set to `None`, a full table scan will be done (slow). + The maximum number of rows to scan for schema inference. + If set to `0`, all columns will be read as `pl.String`. + If set to `None`, the full data may be scanned *(this is slow)*. batch_size Number of lines to read into the buffer at once. Modify this to change performance. @@ -185,10 +181,14 @@ def read_csv( Notes ----- - This operation defaults to a `rechunk` operation at the end, meaning that - all data will be stored continuously in memory. - Set `rechunk=False` if you are benchmarking the csv-reader. A `rechunk` is - an expensive operation. + If the schema is inferred incorrectly (e.g. as `pl.Int64` instead of `pl.Float64`), + try to increase the number of lines used to infer the schema with + `infer_schema_length` or override the inferred dtype for those columns with + `dtypes`. + + This operation defaults to a `rechunk` operation at the end, meaning that all data + will be stored continuously in memory. Set `rechunk=False` if you are benchmarking + the csv-reader. A `rechunk` is an expensive operation. Examples -------- @@ -508,6 +508,7 @@ def read_csv_batched( - `List[str]`: All values equal to any string in this list will be null. - `Dict[str, str]`: A dictionary that maps column name to a null value string. + missing_utf8_is_empty_string By default a missing value is considered to be null; if you would prefer missing utf8 values to be treated as the empty string you can set this param True. @@ -523,9 +524,9 @@ def read_csv_batched( Number of threads to use in csv parsing. Defaults to the number of physical cpu's of your system. infer_schema_length - Maximum number of lines to read to infer schema. - If set to 0, all columns will be read as `pl.String`. - If set to `None`, a full table scan will be done (slow). + The maximum number of rows to scan for schema inference. + If set to `0`, all columns will be read as `pl.String`. + If set to `None`, the full data may be scanned *(this is slow)*. batch_size Number of lines to read into the buffer at once. @@ -801,6 +802,7 @@ def scan_csv( - `List[str]`: All values equal to any string in this list will be null. - `Dict[str, str]`: A dictionary that maps column name to a null value string. + missing_utf8_is_empty_string By default a missing value is considered to be null; if you would prefer missing utf8 values to be treated as the empty string you can set this param True. @@ -814,9 +816,9 @@ def scan_csv( Apply a function over the column names just in time (when they are determined); this function will receive (and should return) a list of column names. infer_schema_length - Maximum number of lines to read to infer schema. - If set to 0, all columns will be read as `pl.String`. - If set to `None`, a full table scan will be done (slow). + The maximum number of rows to scan for schema inference. + If set to `0`, all columns will be read as `pl.String`. + If set to `None`, the full data may be scanned *(this is slow)*. n_rows Stop reading from CSV file after reading `n_rows`. encoding : {'utf8', 'utf8-lossy'} diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index 8ed49bd69216..76db913aeb02 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -151,7 +151,7 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - # iif we created it and are finished with it, we can + # if we created it and are finished with it, we can # close the cursor (but NOT the connection) if self.can_close_cursor: self.cursor.close() @@ -279,21 +279,30 @@ def _from_rows( from polars import DataFrame if hasattr(self.result, "fetchall"): - description = ( - self.result.cursor.description - if self.driver_name == "sqlalchemy" - else self.result.description - ) - column_names = [desc[0] for desc in description] + if self.driver_name == "sqlalchemy": + if hasattr(self.result, "cursor"): + cursor_desc = {d[0]: d[1] for d in self.result.cursor.description} + elif hasattr(self.result, "_metadata"): + cursor_desc = {k: None for k in self.result._metadata.keys} + else: + msg = f"Unable to determine metadata from query result; {self.result!r}" + raise ValueError(msg) + else: + cursor_desc = {d[0]: d[1] for d in self.result.description} + + # TODO: refine types based on the cursor description's type_code, + # if/where available? (for now, we just read the column names) + result_columns = list(cursor_desc) + frames = ( DataFrame( data=rows, - schema=column_names, + schema=result_columns, schema_overrides=schema_overrides, orient="row", ) for rows in ( - self._fetchmany_rows(self.result, batch_size) + list(self._fetchmany_rows(self.result, batch_size)) if iter_batches else [self._fetchall_rows(self.result)] # type: ignore[list-item] ) @@ -318,18 +327,33 @@ def execute( options = options or {} cursor_execute = self.cursor.execute - if self.driver_name == "sqlalchemy" and isinstance(query, str): - params = options.get("parameters") - if isinstance(params, Sequence) and hasattr(self.cursor, "exec_driver_sql"): - cursor_execute = self.cursor.exec_driver_sql - if isinstance(params, list) and not all( - isinstance(p, (dict, tuple)) for p in params + if self.driver_name == "sqlalchemy": + from sqlalchemy.orm import Session + + param_key = "parameters" + if ( + isinstance(self.cursor, Session) + and "parameters" in options + and "params" not in options + ): + options = options.copy() + options["params"] = options.pop("parameters") + param_key = "params" + + if isinstance(query, str): + params = options.get(param_key) + if isinstance(params, Sequence) and hasattr( + self.cursor, "exec_driver_sql" ): - options["parameters"] = tuple(params) - else: - from sqlalchemy.sql import text + cursor_execute = self.cursor.exec_driver_sql + if isinstance(params, list) and not all( + isinstance(p, (dict, tuple)) for p in params + ): + options[param_key] = tuple(params) + else: + from sqlalchemy.sql import text - query = text(query) # type: ignore[assignment] + query = text(query) # type: ignore[assignment] # note: some cursor execute methods (eg: sqlite3) only take positional # params, hence the slightly convoluted resolution of the 'options' dict @@ -440,9 +464,10 @@ def read_database( # noqa: D417 be a suitable "Selectable", otherwise it is expected to be a string). connection An instantiated connection (or cursor/client object) that the query can be - executed against. Can also pass a valid ODBC connection string, starting with - "Driver=", in which case the `arrow-odbc` package will be used to establish - the connection and return Arrow-native data to Polars. + executed against. Can also pass a valid ODBC connection string, identified as + such if it contains the string "Driver=", in which case the `arrow-odbc` + package will be used to establish the connection and return Arrow-native data + to Polars. iter_batches Return an iterator of DataFrames, where each DataFrame represents a batch of data returned by the query; this can be useful for processing large resultsets @@ -489,7 +514,8 @@ def read_database( # noqa: D417 `connectorx` will optimise translation of the result set into Arrow format in Rust, whereas these libraries will return row-wise data to Python *before* we can load into Arrow. Note that you can easily determine the connection's - URI from a SQLAlchemy engine object by calling `str(conn.engine.url)`. + URI from a SQLAlchemy engine object by calling + `conn.engine.url.render_as_string(hide_password=False)`. * If polars has to create a cursor from your connection in order to execute the query then that cursor will be automatically closed when the query completes; @@ -541,7 +567,7 @@ def read_database( # noqa: D417 """ # noqa: W505 if isinstance(connection, str): # check for odbc connection string - if re.sub(r"\s", "", connection[:20]).lower().startswith("driver="): + if re.search(r"\bdriver\s*=\s*{[^}]+?}", connection, re.IGNORECASE): try: import arrow_odbc # noqa: F401 except ModuleNotFoundError: diff --git a/py-polars/polars/io/delta.py b/py-polars/polars/io/delta.py index ba835a3b829a..f132fb589944 100644 --- a/py-polars/polars/io/delta.py +++ b/py-polars/polars/io/delta.py @@ -45,7 +45,7 @@ def read_delta( For cloud storages, this may include configurations for authentication etc. More info is available `here - `__. + `__. delta_table_options Additional keyword arguments while reading a Delta lake Table. pyarrow_options @@ -167,7 +167,7 @@ def scan_delta( For cloud storages, this may include configurations for authentication etc. More info is available `here - `__. + `__. delta_table_options Additional keyword arguments while reading a Delta lake Table. pyarrow_options @@ -291,7 +291,7 @@ def _get_delta_lake_table( Notes ----- Make sure to install deltalake>=0.8.0. Read the documentation - `here `_. + `here `_. """ _check_if_delta_available() diff --git a/py-polars/polars/io/ipc/functions.py b/py-polars/polars/io/ipc/functions.py index 8ca0c5b3af4b..411ed43043d9 100644 --- a/py-polars/polars/io/ipc/functions.py +++ b/py-polars/polars/io/ipc/functions.py @@ -2,7 +2,7 @@ import contextlib from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, BinaryIO +from typing import IO, TYPE_CHECKING, Any import polars._reexport as pl from polars.dependencies import _PYARROW_AVAILABLE @@ -14,15 +14,13 @@ from polars.polars import read_ipc_schema as _read_ipc_schema if TYPE_CHECKING: - from io import BytesIO - from polars import DataFrame, DataType, LazyFrame @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def read_ipc( - source: str | BinaryIO | BytesIO | Path | bytes, + source: str | Path | IO[bytes] | bytes, *, columns: list[int] | list[str] | None = None, n_rows: int | None = None, @@ -114,7 +112,7 @@ def read_ipc( @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def read_ipc_stream( - source: str | BinaryIO | BytesIO | Path | bytes, + source: str | Path | IO[bytes] | bytes, *, columns: list[int] | list[str] | None = None, n_rows: int | None = None, diff --git a/py-polars/polars/io/json.py b/py-polars/polars/io/json.py index 3f083ed6bfa1..099c6936e780 100644 --- a/py-polars/polars/io/json.py +++ b/py-polars/polars/io/json.py @@ -16,9 +16,9 @@ def read_json( source: str | Path | IOBase | bytes, *, - infer_schema_length: int | None = N_INFER_DEFAULT, schema: SchemaDefinition | None = None, schema_overrides: SchemaDefinition | None = None, + infer_schema_length: int | None = N_INFER_DEFAULT, ) -> DataFrame: """ Read into a DataFrame from a JSON file. @@ -29,8 +29,6 @@ def read_json( Path to a file or a file-like object (by file-like object, we refer to objects that have a `read()` method, such as a file handler (e.g. via builtin `open` function) or `BytesIO`). - infer_schema_length - Infer the schema from the first `infer_schema_length` rows. schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict The DataFrame schema may be declared in several ways: @@ -44,6 +42,9 @@ def read_json( schema_overrides : dict, default None Support type specification or override of one or more columns; note that any dtypes inferred from the schema param will be overridden. + infer_schema_length + The maximum number of rows to scan for schema inference. + If set to `None`, the full data may be scanned *(this is slow)*. See Also -------- @@ -51,7 +52,7 @@ def read_json( """ return pl.DataFrame._read_json( source, - infer_schema_length=infer_schema_length, schema=schema, schema_overrides=schema_overrides, + infer_schema_length=infer_schema_length, ) diff --git a/py-polars/polars/io/ndjson.py b/py-polars/polars/io/ndjson.py index 4f5dc87e61fb..d8e5aa9d403a 100644 --- a/py-polars/polars/io/ndjson.py +++ b/py-polars/polars/io/ndjson.py @@ -59,6 +59,7 @@ def read_ndjson( def scan_ndjson( source: str | Path | list[str] | list[Path], *, + schema: SchemaDefinition | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, batch_size: int | None = 1024, n_rows: int | None = None, @@ -66,7 +67,6 @@ def scan_ndjson( rechunk: bool = False, row_index_name: str | None = None, row_index_offset: int = 0, - schema: SchemaDefinition | None = None, ignore_errors: bool = False, ) -> LazyFrame: """ @@ -79,8 +79,19 @@ def scan_ndjson( ---------- source Path to a file. + schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict + The DataFrame schema may be declared in several ways: + + * As a dict of {name:type} pairs; if type is None, it will be auto-inferred. + * As a list of column names; in this case types are automatically inferred. + * As a list of (name,type) pairs; this is equivalent to the dictionary form. + + If you supply a list of column names that does not match the names in the + underlying data, the names given here will overwrite them. The number + of names given in the schema should match the underlying data dimensions. infer_schema_length - Infer the schema from the first `infer_schema_length` rows. + The maximum number of rows to scan for schema inference. + If set to `None`, the full data may be scanned *(this is slow)*. batch_size Number of rows to read in each batch. n_rows @@ -94,16 +105,6 @@ def scan_ndjson( DataFrame row_index_offset Offset to start the row index column (only use if the name is set) - schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict - The DataFrame schema may be declared in several ways: - - * As a dict of {name:type} pairs; if type is None, it will be auto-inferred. - * As a list of column names; in this case types are automatically inferred. - * As a list of (name,type) pairs; this is equivalent to the dictionary form. - - If you supply a list of column names that does not match the names in the - underlying data, the names given here will overwrite them. The number - of names given in the schema should match the underlying data dimensions. ignore_errors Return `Null` if parsing fails because of schema mismatches. """ diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index c731a207a94e..1053e4b6f2cf 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -3,17 +3,24 @@ import re from contextlib import nullcontext from datetime import time -from io import BytesIO, StringIO +from io import BufferedReader, BytesIO, StringIO from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, NoReturn, Sequence, overload +from typing import IO, TYPE_CHECKING, Any, Callable, NoReturn, Sequence, overload import polars._reexport as pl from polars import functions as F -from polars.datatypes import FLOAT_DTYPES, Date, Datetime, Int64, Null, String +from polars.datatypes import ( + FLOAT_DTYPES, + NUMERIC_DTYPES, + Date, + Datetime, + Int64, + Null, + String, +) from polars.dependencies import import_optional from polars.exceptions import NoDataError, ParameterCollisionError -from polars.io._utils import _looks_like_url, _process_file_url +from polars.io._utils import PortableTemporaryFile, _looks_like_url, _process_file_url from polars.io.csv.functions import read_csv from polars.utils.deprecation import deprecate_renamed_parameter from polars.utils.various import normalize_filepath @@ -26,13 +33,13 @@ @overload def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None = ..., sheet_name: str, engine: ExcelSpreadsheetEngine | None = ..., engine_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., + read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: @@ -41,13 +48,13 @@ def read_excel( @overload def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None = ..., sheet_name: None = ..., engine: ExcelSpreadsheetEngine | None = ..., engine_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., + read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: @@ -56,13 +63,13 @@ def read_excel( @overload def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int, sheet_name: str, engine: ExcelSpreadsheetEngine | None = ..., engine_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., + read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> NoReturn: @@ -73,13 +80,13 @@ def read_excel( # Literal[0] overlaps with the return value for other integers @overload # type: ignore[overload-overlap] def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: Literal[0] | Sequence[int], sheet_name: None = ..., engine: ExcelSpreadsheetEngine | None = ..., engine_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., + read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: @@ -88,13 +95,13 @@ def read_excel( @overload def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int, sheet_name: None = ..., engine: ExcelSpreadsheetEngine | None = ..., engine_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., + read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: @@ -103,13 +110,13 @@ def read_excel( @overload def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None, sheet_name: list[str] | tuple[str], engine: ExcelSpreadsheetEngine | None = ..., engine_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., + read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: @@ -117,14 +124,15 @@ def read_excel( @deprecate_renamed_parameter("xlsx2csv_options", "engine_options", version="0.20.6") +@deprecate_renamed_parameter("read_csv_options", "read_options", version="0.20.7") def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int | Sequence[int] | None = None, sheet_name: str | list[str] | tuple[str] | None = None, engine: ExcelSpreadsheetEngine | None = None, engine_options: dict[str, Any] | None = None, - read_csv_options: dict[str, Any] | None = None, + read_options: dict[str, Any] | None = None, schema_overrides: SchemaDict | None = None, raise_if_empty: bool = True, ) -> pl.DataFrame | dict[str, pl.DataFrame]: @@ -157,7 +165,7 @@ def read_excel( * "xlsx2csv": converts the data to an in-memory CSV before using the native polars `read_csv` method to parse the result. You can pass `engine_options` - and `read_csv_options` to refine the conversion. + and `read_options` to refine the conversion. * "openpyxl": this engine is significantly slower than `xlsx2csv` but supports additional automatic type inference; potentially useful if you are otherwise unable to parse your sheet with the (default) `xlsx2csv` engine in @@ -170,13 +178,19 @@ def read_excel( other options, using the `fastexcel` module to bind calamine. engine_options - Extra options passed to the underlying engine's Workbook-reading constructor. - For example, if using `xlsx2csv` you could pass `{"skip_empty_lines": True}`. - read_csv_options - Extra options passed to :func:`read_csv` for parsing the CSV file returned by - `xlsx2csv.Xlsx2csv().convert()`. This option is *only* applicable when using - the `xlsx2csv` engine. For example, you could pass ``{"has_header": False, - "new_columns": ["a", "b", "c"], "infer_schema_length": None}`` + Additional options passed to the underlying engine's primary parsing + constructor (given below), if supported: + + * "xlsx2csv": `Xlsx2csv` + * "openpyxl": `load_workbook` + * "pyxlsb": `open_workbook` + * "calamine": `n/a` + + read_options + Extra options passed to the function that reads the sheet data (for example, + the `read_csv` method if using the "xlsx2csv" engine, to which you could + pass ``{"infer_schema_length": None}``, or the `load_sheet_by_name` method + if using the "calamine" engine. schema_overrides Support type specification or override of one or more columns. raise_if_empty @@ -187,7 +201,7 @@ def read_excel( ----- When using the default `xlsx2csv` engine the target Excel sheet is first converted to CSV using `xlsx2csv.Xlsx2csv(source).convert()` and then parsed with Polars' - :func:`read_csv` function. You can pass additional options to `read_csv_options` + :func:`read_csv` function. You can pass additional options to `read_options` to influence this part of the parsing pipeline. Returns @@ -209,13 +223,13 @@ def read_excel( Read table data from sheet 3 in an Excel workbook as a DataFrame while skipping empty lines in the sheet. As sheet 3 does not have a header row and the default engine is `xlsx2csv` you can pass the necessary additional settings for this - to the "read_csv_options" parameter; these will be passed to :func:`read_csv`. + to the "read_options" parameter; these will be passed to :func:`read_csv`. >>> pl.read_excel( ... source="test.xlsx", ... sheet_id=3, ... engine_options={"skip_empty_lines": True}, - ... read_csv_options={"has_header": False, "new_columns": ["a", "b", "c"]}, + ... read_options={"has_header": False, "new_columns": ["a", "b", "c"]}, ... ) # doctest: +SKIP If the correct datatypes can't be determined you can use `schema_overrides` and/or @@ -227,14 +241,14 @@ def read_excel( >>> pl.read_excel( ... source="test.xlsx", - ... read_csv_options={"infer_schema_length": 1000}, + ... read_options={"infer_schema_length": 1000}, ... schema_overrides={"dt": pl.Date}, ... ) # doctest: +SKIP The `openpyxl` package can also be used to parse Excel data; it has slightly better default type detection, but is slower than `xlsx2csv`. If you have a sheet that is better read using this package you can set the engine as "openpyxl" (if you - use this engine then `read_csv_options` cannot be set). + use this engine then `read_options` cannot be set). >>> pl.read_excel( ... source="test.xlsx", @@ -242,17 +256,13 @@ def read_excel( ... schema_overrides={"dt": pl.Datetime, "value": pl.Int32}, ... ) # doctest: +SKIP """ - if engine and engine != "xlsx2csv" and read_csv_options: - msg = f"cannot specify `read_csv_options` when engine={engine!r}" - raise ValueError(msg) - return _read_spreadsheet( sheet_id, sheet_name, source=source, engine=engine, engine_options=engine_options, - read_csv_options=read_csv_options, + read_options=read_options, schema_overrides=schema_overrides, raise_if_empty=raise_if_empty, ) @@ -260,7 +270,7 @@ def read_excel( @overload def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None = ..., sheet_name: str, @@ -272,7 +282,7 @@ def read_ods( @overload def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None = ..., sheet_name: None = ..., @@ -284,7 +294,7 @@ def read_ods( @overload def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int, sheet_name: str, @@ -296,7 +306,7 @@ def read_ods( @overload # type: ignore[overload-overlap] def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: Literal[0] | Sequence[int], sheet_name: None = ..., @@ -308,7 +318,7 @@ def read_ods( @overload def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int, sheet_name: None = ..., @@ -320,7 +330,7 @@ def read_ods( @overload def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None, sheet_name: list[str] | tuple[str], @@ -331,7 +341,7 @@ def read_ods( def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int | Sequence[int] | None = None, sheet_name: str | list[str] | tuple[str] | None = None, @@ -390,33 +400,82 @@ def read_ods( source=source, engine="ods", engine_options={}, - read_csv_options={}, + read_options={}, schema_overrides=schema_overrides, raise_if_empty=raise_if_empty, ) +def _identify_from_magic_bytes(data: IO[bytes] | bytes) -> str | None: + if isinstance(data, bytes): + data = BytesIO(data) + + xls_bytes = b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1" # excel 97-2004 + xlsx_bytes = b"PK\x03\x04" # xlsx/openoffice + + initial_position = data.tell() + try: + magic_bytes = data.read(8) + if magic_bytes == xls_bytes: + return "xls" + elif magic_bytes[:4] == xlsx_bytes: + return "xlsx" + return None + finally: + data.seek(initial_position) + + +def _identify_workbook(wb: str | Path | IO[bytes] | bytes) -> str | None: + """Use file extension (and magic bytes) to identify Workbook type.""" + if not isinstance(wb, (str, Path)): + # raw binary data (bytesio, etc) + return _identify_from_magic_bytes(wb) + else: + p = Path(wb) + ext = p.suffix[1:].lower() + + # unambiguous file extensions + if ext in ("xlsx", "xlsm", "xlsb"): + return ext + elif ext[:2] == "od": + return "ods" + + # check magic bytes to resolve ambiguity (eg: xls/xlsx, or no extension) + with p.open("rb") as f: + magic_bytes = BytesIO(f.read(8)) + return _identify_from_magic_bytes(magic_bytes) + + def _read_spreadsheet( sheet_id: int | Sequence[int] | None, sheet_name: str | list[str] | tuple[str] | None, - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, engine: ExcelSpreadsheetEngine | Literal["ods"] | None, engine_options: dict[str, Any] | None = None, - read_csv_options: dict[str, Any] | None = None, + read_options: dict[str, Any] | None = None, schema_overrides: SchemaDict | None = None, *, raise_if_empty: bool = True, ) -> pl.DataFrame | dict[str, pl.DataFrame]: - if isinstance(source, (str, Path)): + if is_file := isinstance(source, (str, Path)): source = normalize_filepath(source) if _looks_like_url(source): source = _process_file_url(source) if engine is None: - if (src := str(source).lower()).endswith(".ods"): - engine = "ods" + if is_file and str(source).lower().endswith(".ods"): + # note: engine cannot be 'None' here (if called from read_ods) + msg = "OpenDocumentSpreadsheet files require use of `read_ods`, not `read_excel`" + raise ValueError(msg) + + # note: eventually want 'calamine' to be the default for all extensions + file_type = _identify_workbook(source) + if file_type == "xlsb": + engine = "pyxlsb" + elif file_type == "xls": + engine = "calamine" else: - engine = "pyxlsb" if src.endswith(".xlsb") else "xlsx2csv" + engine = "xlsx2csv" # establish the reading function, parser, and available worksheets reader_fn, parser, worksheets = _initialise_spreadsheet_parser( @@ -429,8 +488,8 @@ def _read_spreadsheet( name: reader_fn( parser=parser, sheet_name=name, - read_csv_options=read_csv_options, schema_overrides=schema_overrides, + read_options=(read_options or {}), raise_if_empty=raise_if_empty, ) for name in sheet_names @@ -458,6 +517,7 @@ def _get_sheet_names( if sheet_id is not None and sheet_name is not None: msg = f"cannot specify both `sheet_name` ({sheet_name!r}) and `sheet_id` ({sheet_id!r})" raise ValueError(msg) + sheet_names = [] if sheet_id is None and sheet_name is None: sheet_names.append(worksheets[0]["name"]) @@ -497,10 +557,13 @@ def _get_sheet_names( def _initialise_spreadsheet_parser( engine: str | None, - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, engine_options: dict[str, Any], ) -> tuple[Callable[..., pl.DataFrame], Any, list[dict[str, Any]]]: """Instantiate the indicated spreadsheet parser and establish related properties.""" + if isinstance(source, (str, Path)) and not Path(source).exists(): + raise FileNotFoundError(source) + if engine == "xlsx2csv": # default xlsx2csv = import_optional("xlsx2csv") @@ -525,15 +588,17 @@ def _initialise_spreadsheet_parser( elif engine == "calamine": # note: can't read directly from bytes (yet) so - if read_bytesio := isinstance(source, BytesIO): - temp_data = NamedTemporaryFile(delete=True) - with nullcontext() if not read_bytesio else temp_data as tmp: # type: ignore[attr-defined] - if read_bytesio: - tmp.write(source.getvalue()) # type: ignore[union-attr] - source = temp_data.name + read_buffered = False + if read_bytesio := isinstance(source, BytesIO) or ( + read_buffered := isinstance(source, BufferedReader) + ): + temp_data = PortableTemporaryFile(delete=True) - if not Path(source).exists(): # type: ignore[arg-type] - raise FileNotFoundError(source) + with temp_data if (read_bytesio or read_buffered) else nullcontext() as tmp: + if read_bytesio and tmp is not None: + tmp.write(source.read() if read_buffered else source.getvalue()) # type: ignore[union-attr] + source = tmp.name + tmp.close() fxl = import_optional("fastexcel", min_version="0.7.0") parser = fxl.read_excel(source, **engine_options) @@ -571,7 +636,7 @@ def _initialise_spreadsheet_parser( def _csv_buffer_to_frame( csv: StringIO, separator: str, - read_csv_options: dict[str, Any] | None, + read_options: dict[str, Any], schema_overrides: SchemaDict | None, *, raise_if_empty: bool, @@ -587,23 +652,23 @@ def _csv_buffer_to_frame( raise NoDataError(msg) return pl.DataFrame() - if read_csv_options is None: - read_csv_options = {} + if read_options is None: + read_options = {} if schema_overrides: - if (csv_dtypes := read_csv_options.get("dtypes", {})) and set( + if (csv_dtypes := read_options.get("dtypes", {})) and set( csv_dtypes ).intersection(schema_overrides): - msg = "cannot specify columns in both `schema_overrides` and `read_csv_options['dtypes']`" + msg = "cannot specify columns in both `schema_overrides` and `read_options['dtypes']`" raise ParameterCollisionError(msg) - read_csv_options = read_csv_options.copy() - read_csv_options["dtypes"] = {**csv_dtypes, **schema_overrides} + read_options = read_options.copy() + read_options["dtypes"] = {**csv_dtypes, **schema_overrides} # otherwise rewind the buffer and parse as csv csv.seek(0) df = read_csv( csv, separator=separator, - **read_csv_options, + **read_options, ) return _drop_null_data(df, raise_if_empty=raise_if_empty) @@ -616,7 +681,14 @@ def _drop_null_data(df: pl.DataFrame, *, raise_if_empty: bool) -> pl.DataFrame: # will be named as "_duplicated_{n}" (or "__UNNAMED__{n}" from calamine) if col_name == "" or re.match(r"(_duplicated_|__UNNAMED__)\d+$", col_name): col = df[col_name] - if col.dtype == Null or col.null_count() == len(df): + if ( + col.dtype == Null + or col.null_count() == len(df) + or ( + col.dtype in NUMERIC_DTYPES + and col.replace(0, None).null_count() == len(df) + ) + ): null_cols.append(col_name) if null_cols: df = df.drop(*null_cols) @@ -637,7 +709,7 @@ def _drop_null_data(df: pl.DataFrame, *, raise_if_empty: bool) -> pl.DataFrame: def _read_spreadsheet_ods( parser: Any, sheet_name: str | None, - read_csv_options: dict[str, Any] | None, + read_options: dict[str, Any], schema_overrides: SchemaDict | None, *, raise_if_empty: bool, @@ -705,7 +777,7 @@ def _read_spreadsheet_ods( def _read_spreadsheet_openpyxl( parser: Any, sheet_name: str | None, - read_csv_options: dict[str, Any] | None, + read_options: dict[str, Any], schema_overrides: SchemaDict | None, *, raise_if_empty: bool, @@ -753,12 +825,12 @@ def _read_spreadsheet_openpyxl( def _read_spreadsheet_calamine( parser: Any, sheet_name: str | None, - read_csv_options: dict[str, Any] | None, + read_options: dict[str, Any], schema_overrides: SchemaDict | None, *, raise_if_empty: bool, ) -> pl.DataFrame: - ws = parser.load_sheet_by_name(sheet_name) + ws = parser.load_sheet_by_name(sheet_name, **read_options) df = ws.to_polars() if schema_overrides: @@ -766,26 +838,27 @@ def _read_spreadsheet_calamine( df = _drop_null_data(df, raise_if_empty=raise_if_empty) - # calamine may read integer data as float; cast back to int where possible. - # do a similar downcast check for datetime -> date dtypes. + # refine dtypes type_checks = [] for c, dtype in df.schema.items(): + # may read integer data as float; cast back to int where possible. if dtype in FLOAT_DTYPES: - check_cast = [F.col(c).floor().eq_missing(F.col(c)), F.col(c).cast(Int64)] + check_cast = [F.col(c).floor().eq(F.col(c)), F.col(c).cast(Int64)] type_checks.append(check_cast) + # do a similar check for datetime columns that have only 00:00:00 times. elif dtype == Datetime: check_cast = [ - F.col(c).drop_nulls().dt.time().eq_missing(time(0, 0, 0)), + F.col(c).dt.time().eq(time(0, 0, 0)), F.col(c).cast(Date), ] type_checks.append(check_cast) if type_checks: - apply_downcast = df.select([d[0] for d in type_checks]).row(0) - - # do a similar check for datetime columns that have only 00:00:00 times. + apply_cast = df.select( + [d[0].all(ignore_nulls=True) for d in type_checks], + ).row(0) if downcast := [ - cast for apply, (_, cast) in zip(apply_downcast, type_checks) if apply + cast for apply, (_, cast) in zip(apply_cast, type_checks) if apply ]: df = df.with_columns(*downcast) @@ -795,7 +868,7 @@ def _read_spreadsheet_calamine( def _read_spreadsheet_pyxlsb( parser: Any, sheet_name: str | None, - read_csv_options: dict[str, Any] | None, + read_options: dict[str, Any], schema_overrides: SchemaDict | None, *, raise_if_empty: bool, @@ -850,7 +923,7 @@ def _read_spreadsheet_pyxlsb( def _read_spreadsheet_xlsx2csv( parser: Any, sheet_name: str | None, - read_csv_options: dict[str, Any] | None, + read_options: dict[str, Any], schema_overrides: SchemaDict | None, *, raise_if_empty: bool, @@ -861,14 +934,14 @@ def _read_spreadsheet_xlsx2csv( outfile=csv_buffer, sheetname=sheet_name, ) - if read_csv_options is None: - read_csv_options = {} - read_csv_options.setdefault("truncate_ragged_lines", True) + if read_options is None: + read_options = {} + read_options.setdefault("truncate_ragged_lines", True) return _csv_buffer_to_frame( csv_buffer, separator=",", - read_csv_options=read_csv_options, + read_options=read_options, schema_overrides=schema_overrides, raise_if_empty=raise_if_empty, ) diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 2881d4f47f1a..2be1a35a75b3 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -53,7 +53,7 @@ is_polars_dtype, py_type_to_dtype, ) -from polars.dependencies import dataframe_api_compat, subprocess +from polars.dependencies import subprocess from polars.io._utils import _is_local_file, _is_supported_cloud from polars.io.csv._utils import _check_arg_is_1byte from polars.io.ipc.anonymous_scan import _scan_ipc_fsspec @@ -170,9 +170,11 @@ class LazyFrame: Whether to interpret two-dimensional data as columns or as rows. If None, the orientation is inferred by matching the columns and data dimensions. If this does not yield conclusive results, column orientation is used. - infer_schema_length : int, default None - Maximum number of rows to read for schema inference; only applies if the input - data is a sequence or generator of rows; other input is read as-is. + infer_schema_length : int or None + The maximum number of rows to scan for schema inference. + If set to `None`, the full data may be scanned *(this is slow)*. + This parameter only applies if the input data is a sequence or generator of + rows; other input is read as-is. nan_to_null : bool, default False If the data comes from one or more numpy arrays, can optionally convert input data np.nan values to null instead. This is a no-op for all other input data. @@ -284,6 +286,7 @@ class LazyFrame: └─────┴─────┴─────┘ """ + __slots__ = ("_ldf",) _ldf: PyLazyFrame _accessors: ClassVar[set[str]] = set() @@ -534,7 +537,7 @@ def _scan_ndjson( cls, source: str | Path | list[str] | list[Path], *, - infer_schema_length: int | None = None, + infer_schema_length: int | None = N_INFER_DEFAULT, schema: SchemaDefinition | None = None, batch_size: int | None = None, n_rows: int | None = None, @@ -693,19 +696,6 @@ def schema(self) -> OrderedDict[str, DataType]: """ return OrderedDict(self._ldf.schema()) - def __dataframe_consortium_standard__( - self, *, api_version: str | None = None - ) -> Any: - """ - Provide entry point to the Consortium DataFrame Standard API. - - This is developed and maintained outside of polars. - Please report any issues to https://github.com/data-apis/dataframe-api-compat. - """ - return dataframe_api_compat.polars_standard.convert_to_standard_compliant_dataframe( - self, api_version=api_version - ) - @property def width(self) -> int: """ @@ -1012,10 +1002,11 @@ def describe( Customize which percentiles are displayed, applying linear interpolation: - >>> lf.describe( - ... percentiles=[0.1, 0.3, 0.5, 0.7, 0.9], - ... interpolation="linear", - ... ) + >>> with pl.Config(tbl_rows=12): + ... lf.describe( + ... percentiles=[0.1, 0.3, 0.5, 0.7, 0.9], + ... interpolation="linear", + ... ) shape: (11, 7) ┌────────────┬──────────┬──────────┬──────────┬──────┬────────────┬──────────┐ │ statistic ┆ float ┆ int ┆ bool ┆ str ┆ date ┆ time │ @@ -1152,6 +1143,7 @@ def explain( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, streaming: bool = False, + tree_format: bool = False, ) -> str: """ Create a string representation of the query plan. @@ -1181,6 +1173,8 @@ def explain( Common subexpressions will be cached and reused. streaming Run parts of the query in a streaming fashion (this is in an alpha state) + tree_format + Format the output as a tree Examples -------- @@ -1207,7 +1201,12 @@ def explain( streaming, _eager=False, ) + if tree_format: + return ldf.describe_optimized_plan_tree() return ldf.describe_optimized_plan() + + if tree_format: + return self._ldf.describe_plan_tree() return self._ldf.describe_plan() def show_graph( @@ -3126,27 +3125,29 @@ def select_seq( ) return self._from_pyldf(self._ldf.select_seq(pyexprs)) + @deprecate_parameter_as_positional("by", version="0.20.7") def group_by( self, - by: IntoExpr | Iterable[IntoExpr], - *more_by: IntoExpr, + *by: IntoExpr | Iterable[IntoExpr], maintain_order: bool = False, + **named_by: IntoExpr, ) -> LazyGroupBy: """ Start a group by operation. Parameters ---------- - by + *by Column(s) to group by. Accepts expression input. Strings are parsed as column names. - *more_by - Additional columns to group by, specified as positional arguments. maintain_order Ensure that the order of the groups is consistent with the input data. This is slower than a default group by. Setting this to `True` blocks the possibility to run on the streaming engine. + **named_by + Additional columns to group by, specified as keyword arguments. + The columns will be renamed to the keyword used. Examples -------- @@ -3219,7 +3220,7 @@ def group_by( │ c ┆ 1 ┆ 1.0 │ └─────┴─────┴─────┘ """ - exprs = parse_as_list_of_expressions(by, *more_by) + exprs = parse_as_list_of_expressions(*by, **named_by) lgb = self._ldf.group_by(exprs, maintain_order) return LazyGroupBy(lgb) @@ -3234,7 +3235,7 @@ def rolling( check_sorted: bool = True, ) -> LazyGroupBy: """ - Create rolling groups based on a time, Int32, or Int64 column. + Create rolling groups based on a temporal or integer column. Different from a `dynamic_group_by` the windows are now determined by the individual values and are not of constant intervals. For constant intervals @@ -3278,11 +3279,6 @@ def rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a rolling operation on an integer column, the windows are defined by: - - - "1i" # length 1 - - "10i" # length 10 - Parameters ---------- index_column @@ -3292,8 +3288,8 @@ def rolling( then it must be sorted in ascending order within each group). In case of a rolling group by on indices, dtype needs to be one of - {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if - performance matters use an Int64 column. + {UInt32, UInt64, Int32, Int64}. Note that the first three get temporarily + cast to Int64, so if performance matters use an Int64 column. period length of the window - must be non-negative offset @@ -3953,7 +3949,7 @@ def join( * *outer_coalesce* Same as 'outer', but coalesces the key columns * *cross* - Returns the cartisian product of rows from both tables + Returns the Cartesian product of rows from both tables * *semi* Filter rows that have a match in the right table. * *anti* @@ -4814,10 +4810,16 @@ def first(self) -> Self: """ return self.slice(0, 1) + @deprecate_function( + "Use `select(pl.all().approx_n_unique())` instead.", version="0.20.11" + ) def approx_n_unique(self) -> Self: """ Approximate count of unique values. + .. deprecated:: 0.20.11 + Use `select(pl.all().approx_n_unique())` instead. + This is done using the HyperLogLog++ algorithm for cardinality estimation. Examples @@ -4828,7 +4830,7 @@ def approx_n_unique(self) -> Self: ... "b": [1, 2, 1, 1], ... } ... ) - >>> lf.approx_n_unique().collect() + >>> lf.approx_n_unique().collect() # doctest: +SKIP shape: (1, 2) ┌─────┬─────┐ │ a ┆ b │ @@ -5446,7 +5448,7 @@ def explode( ---------- columns Column names, expressions, or a selector defining them. The underlying - columns being exploded must be of List or String datatype. + columns being exploded must be of the `List` or `Array` data type. *more_columns Additional names of columns to explode, specified as positional arguments. diff --git a/py-polars/polars/lazyframe/group_by.py b/py-polars/polars/lazyframe/group_by.py index b8e3aa588c7c..ca6b712bc481 100644 --- a/py-polars/polars/lazyframe/group_by.py +++ b/py-polars/polars/lazyframe/group_by.py @@ -347,16 +347,16 @@ def len(self) -> LazyFrame: ... "b": [1, None, 2], ... } ... ) - >>> lf.group_by("a").count().collect() # doctest: +SKIP + >>> lf.group_by("a").len().collect() # doctest: +SKIP shape: (2, 2) - ┌────────┬───────┐ - │ a ┆ count │ - │ --- ┆ --- │ - │ str ┆ u32 │ - ╞════════╪═══════╡ - │ apple ┆ 2 │ - │ orange ┆ 1 │ - └────────┴───────┘ + ┌────────┬─────┐ + │ a ┆ len │ + │ --- ┆ --- │ + │ str ┆ u32 │ + ╞════════╪═════╡ + │ apple ┆ 2 │ + │ orange ┆ 1 │ + └────────┴─────┘ """ return self.agg(F.len()) diff --git a/py-polars/polars/meta/__init__.py b/py-polars/polars/meta/__init__.py new file mode 100644 index 000000000000..b9e84653ebc8 --- /dev/null +++ b/py-polars/polars/meta/__init__.py @@ -0,0 +1,13 @@ +"""Public functions that provide information about the Polars package or the environment it runs in.""" # noqa: W505 +from polars.meta.build import build_info +from polars.meta.index_type import get_index_type +from polars.meta.thread_pool import thread_pool_size, threadpool_size +from polars.meta.versions import show_versions + +__all__ = [ + "build_info", + "get_index_type", + "show_versions", + "thread_pool_size", + "threadpool_size", +] diff --git a/py-polars/polars/meta/build.py b/py-polars/polars/meta/build.py new file mode 100644 index 000000000000..d38d92fc4414 --- /dev/null +++ b/py-polars/polars/meta/build.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Any + +from polars.utils._polars_version import get_polars_version + +try: + from polars.polars import __build__ +except ImportError: + __build__ = {} + +__build__["version"] = get_polars_version() or "" + + +def build_info() -> dict[str, Any]: + """ + Return detailed Polars build information. + + The dictionary with build information contains the following keys: + + - `"build"` + - `"info-time"` + - `"dependencies"` + - `"features"` + - `"host"` + - `"target"` + - `"git"` + - `"version"` + + If Polars was compiled without the `build_info` feature flag, only the `"version"` + key is included. + + Notes + ----- + `pyo3-built`_ is used to generate the build information. + + .. _pyo3-built: https://github.com/PyO3/pyo3-built + """ + return __build__ diff --git a/py-polars/polars/meta/index_type.py b/py-polars/polars/meta/index_type.py new file mode 100644 index 000000000000..2a8d91f32377 --- /dev/null +++ b/py-polars/polars/meta/index_type.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import contextlib +from typing import TYPE_CHECKING + +with contextlib.suppress(ImportError): # Module not available when building docs + import polars.polars as plr + +if TYPE_CHECKING: + from polars.datatypes import DataType + + +def get_index_type() -> DataType: + """ + Return the data type used for Polars indexing. + + Returns + ------- + DataType + :class:`UInt32` in regular Polars, :class:`UInt64` in bigidx Polars. + + Examples + -------- + >>> pl.get_index_type() + UInt32 + """ + return plr.get_index_type() diff --git a/py-polars/polars/meta/thread_pool.py b/py-polars/polars/meta/thread_pool.py new file mode 100644 index 000000000000..446eb486ceb2 --- /dev/null +++ b/py-polars/polars/meta/thread_pool.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import contextlib + +from polars.utils.deprecation import deprecate_renamed_function + +with contextlib.suppress(ImportError): # Module not available when building docs + import polars.polars as plr + + +def thread_pool_size() -> int: + """ + Return the number of threads in the Polars thread pool. + + Notes + ----- + The thread pool size can be overridden by setting the `POLARS_MAX_THREADS` + environment variable before process start. The thread pool is not behind a + lock, so it cannot be modified once set. A reasonable use case for this might + be temporarily limiting the number of threads before importing Polars in a + PySpark UDF or similar context. Otherwise, it is strongly recommended not to + override this value as it will be set automatically by the engine. + + Examples + -------- + >>> pl.thread_pool_size() # doctest: +SKIP + 16 + """ + return plr.thread_pool_size() + + +@deprecate_renamed_function("thread_pool_size", version="0.20.7") +def threadpool_size() -> int: + """ + Return the number of threads in the Polars thread pool. + + .. deprecated:: 0.20.7 + This function has been renamed to :func:`thread_pool_size`. + """ + return thread_pool_size() diff --git a/py-polars/polars/utils/show_versions.py b/py-polars/polars/meta/versions.py similarity index 89% rename from py-polars/polars/utils/show_versions.py rename to py-polars/polars/meta/versions.py index d58bd860969e..305b2e2a17f8 100644 --- a/py-polars/polars/utils/show_versions.py +++ b/py-polars/polars/meta/versions.py @@ -2,13 +2,13 @@ import sys -from polars.utils.meta import get_index_type -from polars.utils.polars_version import get_polars_version +from polars.meta.index_type import get_index_type +from polars.utils._polars_version import get_polars_version def show_versions() -> None: - r""" - Print out version of Polars and dependencies to stdout. + """ + Print out the version of Polars and its optional dependencies. Examples -------- @@ -38,8 +38,8 @@ def show_versions() -> None: xlsx2csv: 0.8.1 xlsxwriter: 3.1.9 """ # noqa: W505 - # note: we import 'platform' here (rather than at the top of the - # module) as a micro-optimisation for polars' initial import + # Note: we import 'platform' here (rather than at the top of the + # module) as a micro-optimization for polars' initial import import platform deps = _get_dependency_info() diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index 0793f5b9ee57..846dd4472bd6 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -242,6 +242,7 @@ def _combine_as_selector( class _selector_proxy_(Expr): """Base column selector expression/proxy.""" + __slots__ = ("_attrs", "_repr_override") _attrs: dict[str, Any] _repr_override: str diff --git a/py-polars/polars/series/_numpy.py b/py-polars/polars/series/_numpy.py index 5b07af187005..6163172fc478 100644 --- a/py-polars/polars/series/_numpy.py +++ b/py-polars/polars/series/_numpy.py @@ -1,24 +1,21 @@ from __future__ import annotations import ctypes -from typing import TYPE_CHECKING, Any +from typing import Any import numpy as np -if TYPE_CHECKING: - from polars import Series - # https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array class SeriesView(np.ndarray): # type: ignore[type-arg] def __new__( - cls, input_array: np.ndarray[Any, Any], owned_series: Series + cls, input_array: np.ndarray[Any, Any], owned_object: Any ) -> SeriesView: # Input array is an already formed ndarray instance # We first cast to be our class type obj = input_array.view(cls) # add the new attribute to the created instance - obj.owned_series = owned_series + obj.owned_series = owned_object # Finally, we must return the newly created object: return obj @@ -30,7 +27,9 @@ def __array_finalize__(self, obj: Any) -> None: # https://stackoverflow.com/questions/4355524/getting-data-from-ctypes-array-into-numpy -def _ptr_to_numpy(ptr: int, len: int, ptr_type: Any) -> np.ndarray[Any, Any]: +def _ptr_to_numpy( + ptr: int, shape: int | tuple[int, int] | tuple[int], ptr_type: Any +) -> np.ndarray[Any, Any]: """ Create a memory block view as a numpy array. @@ -38,8 +37,8 @@ def _ptr_to_numpy(ptr: int, len: int, ptr_type: Any) -> np.ndarray[Any, Any]: ---------- ptr C/Rust ptr casted to usize. - len - Length of the array values. + shape + Shape of the array values. ptr_type Example: f32: ctypes.c_float) @@ -50,4 +49,6 @@ def _ptr_to_numpy(ptr: int, len: int, ptr_type: Any) -> np.ndarray[Any, Any]: View of memory block as numpy array. """ ptr_ctype = ctypes.cast(ptr, ctypes.POINTER(ptr_type)) - return np.ctypeslib.as_array(ptr_ctype, (len,)) + if isinstance(shape, int): + shape = (shape,) + return np.ctypeslib.as_array(ptr_ctype, shape) diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index ea7842657036..4a547485f962 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -1,8 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Sequence +from polars import functions as F from polars.series.utils import expr_dispatch +from polars.utils._wrap import wrap_s if TYPE_CHECKING: from datetime import date, datetime, time @@ -75,6 +77,54 @@ def sum(self) -> Series: └─────┘ """ + def std(self, ddof: int = 1) -> Series: + """ + Compute the std of the values of the sub-arrays. + + Examples + -------- + >>> s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2)) + >>> s.arr.std() + shape: (2,) + Series: 'a' [f64] + [ + 0.707107 + 0.707107 + ] + """ + + def var(self, ddof: int = 1) -> Series: + """ + Compute the var of the values of the sub-arrays. + + Examples + -------- + >>> s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2)) + >>> s.arr.var() + shape: (2,) + Series: 'a' [f64] + [ + 0.5 + 0.5 + ] + """ + + def median(self) -> Series: + """ + Compute the median of the values of the sub-arrays. + + Examples + -------- + >>> s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2)) + >>> s.arr.median() + shape: (2,) + Series: 'a' [f64] + [ + 1.5 + 3.5 + ] + """ + def unique(self, *, maintain_order: bool = False) -> Series: """ Get the unique/distinct values in the array. @@ -458,3 +508,96 @@ def count_matches(self, element: IntoExpr) -> Series: ] """ + + def to_struct( + self, + fields: Callable[[int], str] | Sequence[str] | None = None, + ) -> Series: + """ + Convert the series of type `Array` to a series of type `Struct`. + + Parameters + ---------- + fields + If the name and number of the desired fields is known in advance + a list of field names can be given, which will be assigned by index. + Otherwise, to dynamically assign field names, a custom function can be + used; if neither are set, fields will be `field_0, field_1 .. field_n`. + + Examples + -------- + Convert array to struct with default field name assignment: + + >>> s1 = pl.Series("n", [[0, 1, 2], [3, 4, 5]], dtype=pl.Array(pl.Int8, 3)) + >>> s2 = s1.arr.to_struct() + >>> s2 + shape: (2,) + Series: 'n' [struct[3]] + [ + {0,1,2} + {3,4,5} + ] + >>> s2.struct.fields + ['field_0', 'field_1', 'field_2'] + + Convert array to struct with field name assignment by function/index: + + >>> s3 = s1.arr.to_struct(fields=lambda idx: f"n{idx:02}") + >>> s3.struct.fields + ['n00', 'n01', 'n02'] + + Convert array to struct with field name assignment by + index from a list of names: + + >>> s1.arr.to_struct(fields=["one", "two", "three"]).struct.unnest() + shape: (2, 3) + ┌─────┬─────┬───────┐ + │ one ┆ two ┆ three │ + │ --- ┆ --- ┆ --- │ + │ i8 ┆ i8 ┆ i8 │ + ╞═════╪═════╪═══════╡ + │ 0 ┆ 1 ┆ 2 │ + │ 3 ┆ 4 ┆ 5 │ + └─────┴─────┴───────┘ + """ + s = wrap_s(self._s) + return s.to_frame().select(F.col(s.name).arr.to_struct(fields)).to_series() + + def shift(self, n: int | IntoExprColumn = 1) -> Series: + """ + Shift array values by the given number of indices. + + Parameters + ---------- + n + Number of indices to shift forward. If a negative value is passed, values + are shifted in the opposite direction instead. + + Notes + ----- + This method is similar to the `LAG` operation in SQL when the value for `n` + is positive. With a negative value for `n`, it is similar to `LEAD`. + + Examples + -------- + By default, array values are shifted forward by one index. + + >>> s = pl.Series([[1, 2, 3], [4, 5, 6]], dtype=pl.Array(pl.Int64, 3)) + >>> s.arr.shift() + shape: (2,) + Series: '' [array[i64, 3]] + [ + [null, 1, 2] + [null, 4, 5] + ] + + Pass a negative value to shift in the opposite direction instead. + + >>> s.arr.shift(-2) + shape: (2,) + Series: '' [array[i64, 3]] + [ + [3, null, null] + [6, null, null] + ] + """ diff --git a/py-polars/polars/series/binary.py b/py-polars/polars/series/binary.py index ad7ea1f4823c..2796ecb403eb 100644 --- a/py-polars/polars/series/binary.py +++ b/py-polars/polars/series/binary.py @@ -20,7 +20,7 @@ def __init__(self, series: Series): self._s: PySeries = series._s def contains(self, literal: IntoExpr) -> Series: - """ + r""" Check if binaries in Series contain a binary substring. Parameters @@ -32,31 +32,67 @@ def contains(self, literal: IntoExpr) -> Series: ------- Series Series of data type :class:`Boolean`. + + Examples + -------- + >>> s = pl.Series("colors", [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"]) + >>> s.bin.contains(b"\xff") + shape: (3,) + Series: 'colors' [bool] + [ + false + true + true + ] """ def ends_with(self, suffix: IntoExpr) -> Series: - """ + r""" Check if string values end with a binary substring. Parameters ---------- suffix Suffix substring. + + Examples + -------- + >>> s = pl.Series("colors", [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"]) + >>> s.bin.ends_with(b"\x00") + shape: (3,) + Series: 'colors' [bool] + [ + true + true + false + ] """ def starts_with(self, prefix: IntoExpr) -> Series: - """ + r""" Check if values start with a binary substring. Parameters ---------- prefix Prefix substring. + + Examples + -------- + >>> s = pl.Series("colors", [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"]) + >>> s.bin.starts_with(b"\x00") + shape: (3,) + Series: 'colors' [bool] + [ + true + false + true + ] """ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Series: - """ - Decode a value using the provided encoding. + r""" + Decode values using the provided encoding. Parameters ---------- @@ -65,11 +101,54 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Series: strict Raise an error if the underlying value cannot be decoded, otherwise mask out with a null value. + + Returns + ------- + Series + Series of data type :class:`String`. + + Examples + -------- + Decode values using hexadecimal encoding. + + >>> s = pl.Series("colors", [b"000000", b"ffff00", b"0000ff"]) + >>> s.bin.decode("hex") + shape: (3,) + Series: 'colors' [binary] + [ + b"\x00\x00\x00" + b"\xff\xff\x00" + b"\x00\x00\xff" + ] + + Decode values using Base64 encoding. + + >>> s = pl.Series("colors", [b"AAAA", b"//8A", b"AAD/"]) + >>> s.bin.decode("base64") + shape: (3,) + Series: 'colors' [binary] + [ + b"\x00\x00\x00" + b"\xff\xff\x00" + b"\x00\x00\xff" + ] + + Set `strict=False` to set invalid values to null instead of raising an error. + + >>> s = pl.Series("colors", [b"000000", b"ffff00", b"invalid_value"]) + >>> s.bin.decode("hex", strict=False) + shape: (3,) + Series: 'colors' [binary] + [ + b"\x00\x00\x00" + b"\xff\xff\x00" + null + ] """ def encode(self, encoding: TransferEncoding) -> Series: - """ - Encode a value using the provided encoding. + r""" + Encode values using the provided encoding. Parameters ---------- @@ -79,5 +158,30 @@ def encode(self, encoding: TransferEncoding) -> Series: Returns ------- Series - Series of data type :class:`Boolean`. + Series of data type :class:`String`. + + Examples + -------- + Encode values using hexadecimal encoding. + + >>> s = pl.Series("colors", [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"]) + >>> s.bin.encode("hex") + shape: (3,) + Series: 'colors' [str] + [ + "000000" + "ffff00" + "0000ff" + ] + + Encode values using Base64 encoding. + + >>> s.bin.encode("base64") + shape: (3,) + Series: 'colors' [str] + [ + "AAAA" + "//8A" + "AAD/" + ] """ diff --git a/py-polars/polars/series/categorical.py b/py-polars/polars/series/categorical.py index 6ebaec6b7edb..03057ea81f98 100644 --- a/py-polars/polars/series/categorical.py +++ b/py-polars/polars/series/categorical.py @@ -43,7 +43,7 @@ def set_ordering(self, ordering: CategoricalOrdering) -> Series: Ordering type: - 'physical' -> Use the physical representation of the categories to - determine the order (default). + determine the order (default). - 'lexical' -> Use the string values to determine the ordering. """ diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index 673bcdd4ee4e..8980f53426d3 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from polars.datatypes import Date, Datetime +from polars.datatypes import Date, Datetime, Duration from polars.series.utils import expr_dispatch from polars.utils._wrap import wrap_s from polars.utils.convert import _to_python_date, _to_python_datetime @@ -82,7 +82,7 @@ def median(self) -> TemporalLiteral | float | None: if out is not None: if s.dtype == Date: return _to_python_date(int(out)) # type: ignore[arg-type] - elif s.dtype == Datetime: + elif s.dtype in (Datetime, Duration): return out # type: ignore[return-value] else: return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[arg-type, attr-defined] @@ -106,7 +106,7 @@ def mean(self) -> TemporalLiteral | float | None: if out is not None: if s.dtype == Date: return _to_python_date(int(out)) # type: ignore[arg-type] - elif s.dtype == Datetime: + elif s.dtype in (Datetime, Duration): return out # type: ignore[return-value] else: return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[arg-type, attr-defined] @@ -1093,6 +1093,11 @@ def convert_time_zone(self, time_zone: str) -> Series: time_zone Time zone for the `Datetime` Series. + Notes + ----- + If converting from a time-zone-naive datetime, then conversion will happen + as if converting from UTC, regardless of your system's time zone. + Examples -------- >>> from datetime import datetime diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 529a47147af5..e7c5eb2f0828 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -233,6 +233,15 @@ def mean(self) -> Series: ] """ + def median(self) -> Series: + """Compute the median value of the arrays in the list.""" + + def std(self) -> Series: + """Compute the std value of the arrays in the list.""" + + def var(self) -> Series: + """Compute the var value of the arrays in the list.""" + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Series: """ Sort the arrays in this column. @@ -300,6 +309,22 @@ def unique(self, *, maintain_order: bool = False) -> Series: ] """ + def n_unique(self) -> Series: + """ + Count the number of unique values in every sub-lists. + + Examples + -------- + >>> s = pl.Series("a", [[1, 1, 2], [2, 3, 4]]) + >>> s.list.n_unique() + shape: (2,) + Series: 'a' [u32] + [ + 2 + 3 + ] + """ + def concat(self, other: list[Series] | Series | list[Any]) -> Series: """ Concat the arrays in a Series dtype List in linear time. @@ -383,6 +408,32 @@ def gather( ] """ + def gather_every( + self, n: int | IntoExprColumn, offset: int | IntoExprColumn = 0 + ) -> Series: + """ + Take every n-th value start from offset in sublists. + + Parameters + ---------- + n + Gather every n-th element. + offset + Starting index. + + Examples + -------- + >>> s = pl.Series("a", [[1, 2, 3], [], [6, 7, 8, 9]]) + >>> s.list.gather_every(2, offset=1) + shape: (3,) + Series: 'a' [list[i64]] + [ + [2] + [] + [7, 9] + ] + """ + def __getitem__(self, item: int) -> Series: return self.get(item) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index fb8d347a5a29..d4f2e3feb27a 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -2,7 +2,6 @@ import contextlib import math -import os from datetime import date, datetime, time, timedelta from decimal import Decimal as PyDecimal from typing import ( @@ -18,7 +17,6 @@ NoReturn, Sequence, Union, - cast, overload, ) @@ -43,6 +41,7 @@ Object, String, Time, + UInt8, UInt32, UInt64, Unknown, @@ -59,13 +58,13 @@ _check_for_numpy, _check_for_pandas, _check_for_pyarrow, - dataframe_api_compat, hvplot, ) from polars.dependencies import numpy as np from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa from polars.exceptions import ModuleUpgradeRequired, ShapeError +from polars.meta import get_index_type from polars.series.array import ArrayNameSpace from polars.series.binary import BinaryNameSpace from polars.series.categorical import CatNameSpace @@ -99,9 +98,9 @@ deprecate_renamed_parameter, issue_deprecation_warning, ) -from polars.utils.meta import get_index_type from polars.utils.unstable import unstable from polars.utils.various import ( + BUILDING_SPHINX_DOCS, _is_generator, no_default, parse_version, @@ -149,7 +148,7 @@ from typing import Self else: from typing_extensions import Self -elif os.getenv("BUILDING_SPHINX_DOCS"): +elif BUILDING_SPHINX_DOCS: property = sphinx_accessor ArrayLike = Union[ @@ -187,9 +186,9 @@ class Series: Data type of the Series if `values` contains no non-null data. .. deprecated:: 0.20.6 - The data type for empty Series will always be Null. - To preserve behavior, check if the resulting Series has data type Null and - cast to the desired data type. + The data type for empty Series will always be `Null`, unless `dtype` is + specified. To preserve behavior, check if the resulting Series has data type + `Null` and cast to the desired data type. This parameter will be removed in the next breaking release. Examples @@ -238,7 +237,8 @@ class Series: ] """ - _s: PySeries = None + __slots__ = ("_s",) + _s: PySeries _accessors: ClassVar[set[str]] = { "arr", "cat", @@ -263,23 +263,24 @@ def __init__( if dtype_if_empty != Null: issue_deprecation_warning( "The `dtype_if_empty` parameter for the Series constructor is deprecated." - " The data type for empty Series will always be Null." + " The data type for empty Series will always be Null, unless `dtype` is specified." " To preserve behavior, check if the resulting Series has data type Null and cast to the desired data type." " This parameter will be removed in the next breaking release.", version="0.20.6", ) - # If 'Unknown' treat as None to attempt inference + # If 'Unknown' treat as None to trigger type inference if dtype == Unknown: dtype = None - # Raise early error on invalid dtype - elif ( - dtype is not None - and not is_polars_dtype(dtype) - and py_type_to_dtype(dtype, raise_unmatched=False) is None - ): - msg = f"given dtype: {dtype!r} is not a valid Polars data type and cannot be converted into one" - raise ValueError(msg) + elif dtype is not None and not is_polars_dtype(dtype): + # Raise early error on invalid dtype + if not is_polars_dtype( + pl_dtype := py_type_to_dtype(dtype, raise_unmatched=False) + ): + msg = f"given dtype: {dtype!r} is not a valid Polars data type and cannot be converted into one" + raise ValueError(msg) + else: + dtype = pl_dtype # Handle case where values are passed as the first argument original_name: str | None = None @@ -331,7 +332,7 @@ def __init__( self._s = arrow_to_pyseries(name, values) elif _check_for_pandas(values) and isinstance( - values, (pd.Series, pd.DatetimeIndex) + values, (pd.Series, pd.Index, pd.DatetimeIndex) ): self._s = pandas_to_pyseries(name, values) @@ -370,11 +371,23 @@ def _from_arrow(cls, name: str, values: pa.Array, *, rechunk: bool = True) -> Se """Construct a Series from an Arrow Array.""" return cls._from_pyseries(arrow_to_pyseries(name, values, rechunk=rechunk)) + @classmethod + def _import_from_c(cls, name: str, pointers: list[tuple[int, int]]) -> Self: + """ + Construct a Series from Arrows C interface. + + Warning + ------- + This will read the `array` pointer without moving it. The host process should + garbage collect the heap pointer, but not its contents. + """ + return cls._from_pyseries(PySeries._import_from_c(name, pointers)) + @classmethod def _from_pandas( cls, name: str, - values: pd.Series[Any] | pd.DatetimeIndex, + values: pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex, *, nan_to_null: bool = True, ) -> Self: @@ -486,9 +499,9 @@ def _from_buffers( the physical data type of `dtype`. Some data types require multiple buffers: - `String`: A data buffer of type `UInt8` and an offsets buffer - of type `Int64`. Note that this does not match how the data - is represented internally and data copy is required to construct - the Series. + of type `Int64`. Note that this does not match how the data + is represented internally and data copy is required to construct + the Series. validity Validity buffer. If specified, must be a Series of data type `Boolean`. @@ -1004,19 +1017,7 @@ def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Self: return self._from_pyseries(getattr(self._s, op_s)(_s)) if isinstance(other, (PyDecimal, int)) and self.dtype.is_decimal(): - # Infer the number's scale. Then use the max of the inferred scale and the - # Series' scale. At present, this will cause arithmetic to fail with a - # PyDecimal that has a scale greater than the Series' scale, but will ensure - # that scale is not lost. _s = sequence_to_pyseries(self.name, [other], dtype=Decimal) - _s = _s.cast( - Decimal( - scale=max( - cast(Decimal, _s.dtype()).scale, cast(Decimal, self.dtype).scale - ) - ), - strict=True, - ) if "rhs" in op_ffi: return self._from_pyseries(getattr(_s, op_s)(self._s)) @@ -1394,15 +1395,16 @@ def __setitem__( msg = f'cannot use "{key!r}" for indexing' raise TypeError(msg) - def __array__(self, dtype: Any = None) -> np.ndarray[Any, Any]: + def __array__(self, dtype: Any | None = None) -> np.ndarray[Any, Any]: """ Numpy __array__ interface protocol. Ensures that `np.asarray(pl.Series(..))` works as expected, see https://numpy.org/devdocs/user/basics.interoperability.html#the-array-method. """ - if not dtype and self.dtype == String and not self.null_count(): + if dtype is None and self.null_count() == 0 and self.dtype == String: dtype = np.dtype("U") + if dtype: return self.to_numpy().__array__(dtype) else: @@ -1430,7 +1432,7 @@ def __array_ufunc__( args.append(arg) elif isinstance(arg, Series): validity_mask &= arg.is_not_null() - args.append(arg._view(ignore_nulls=True)) + args.append(arg.to_physical()._s.to_numpy_view()) else: msg = f"unsupported type {type(arg).__name__!r} for {arg!r}" raise TypeError(msg) @@ -1485,22 +1487,9 @@ def __array_ufunc__( ) raise NotImplementedError(msg) - def __column_consortium_standard__(self, *, api_version: str | None = None) -> Any: - """ - Provide entry point to the Consortium DataFrame Standard API. - - This is developed and maintained outside of polars. - Please report any issues to https://github.com/data-apis/dataframe-api-compat. - """ - return ( - dataframe_api_compat.polars_standard.convert_to_standard_compliant_column( - self, api_version=api_version - ) - ) - def _repr_html_(self) -> str: """Format output data in HTML for display in Jupyter Notebooks.""" - return self.to_frame()._repr_html_(from_series=True) + return self.to_frame()._repr_html_(_from_series=True) @deprecate_renamed_parameter("row", "index", version="0.19.3") def item(self, index: int | None = None) -> Any: @@ -1683,7 +1672,7 @@ def all(self, *, ignore_nulls: bool = True) -> bool | None: Ignore null values (default). If set to `False`, `Kleene logic`_ is used to deal with nulls: - if the column contains any null values and no `True` values, + if the column contains any null values and no `False` values, the output is `None`. .. _Kleene logic: https://en.wikipedia.org/wiki/Three-valued_logic @@ -2075,7 +2064,7 @@ def nan_min(self) -> int | float | date | datetime | timedelta | str: """ return self.to_frame().select_seq(F.col(self.name).nan_min()).item() - def std(self, ddof: int = 1) -> float | None: + def std(self, ddof: int = 1) -> float | timedelta | None: """ Get the standard deviation of this Series. @@ -2092,11 +2081,9 @@ def std(self, ddof: int = 1) -> float | None: >>> s.std() 1.0 """ - if not self.dtype.is_numeric(): - return None return self._s.std(ddof) - def var(self, ddof: int = 1) -> float | None: + def var(self, ddof: int = 1) -> float | timedelta | None: """ Get variance of this Series. @@ -2113,8 +2100,6 @@ def var(self, ddof: int = 1) -> float | None: >>> s.var() 1.0 """ - if not self.dtype.is_numeric(): - return None return self._s.var(ddof) def median(self) -> PythonLiteral | None: @@ -3727,6 +3712,7 @@ def not_(self) -> Series: true ] """ + return self._from_pyseries(self._s.not_()) def is_null(self) -> Series: """ @@ -4156,12 +4142,18 @@ def to_physical(self) -> Series: def to_list(self, *, use_pyarrow: bool | None = None) -> list[Any]: """ - Convert this Series to a Python List. This operation clones data. + Convert this Series to a Python list. + + This operation copies data. Parameters ---------- use_pyarrow - Use pyarrow for the conversion. + Use PyArrow to perform the conversion. + + .. deprecated:: 0.19.9 + This parameter will be removed. The function can safely be called + without the parameter - it should give the exact same result. Examples -------- @@ -4288,40 +4280,44 @@ def is_between( def to_numpy( self, - *args: Any, - zero_copy_only: bool = False, + *, + allow_copy: bool = True, writable: bool = False, use_pyarrow: bool = True, + zero_copy_only: bool | None = None, ) -> np.ndarray[Any, Any]: """ - Convert this Series to numpy. + Convert this Series to a NumPy ndarray. - This operation may clone data but is completely safe. Note that: + This operation may copy data, but is completely safe. Note that: - - data which is purely numeric AND without null values is not cloned; - - floating point `nan` values can be zero-copied; - - booleans can't be zero-copied. + - Data which is purely numeric AND without null values is not cloned + - Floating point `nan` values can be zero-copied + - Booleans cannot be zero-copied - To ensure that no data is cloned, set `zero_copy_only=True`. + To ensure that no data is copied, set `allow_copy=False`. Parameters ---------- - *args - args will be sent to pyarrow.Array.to_numpy. - zero_copy_only - If True, an exception will be raised if the conversion to a numpy - array would require copying the underlying data (e.g. in presence - of nulls, or for non-primitive types). + allow_copy + Allow memory to be copied to perform the conversion. If set to `False`, + causes conversions that are not zero-copy to fail. writable - For numpy arrays created with zero copy (view on the Arrow data), - the resulting array is not writable (Arrow data is immutable). - By setting this to True, a copy of the array is made to ensure - it is writable. + Ensure the resulting array is writable. This will force a copy of the data + if the array was created without copy, as the underlying Arrow data is + immutable. use_pyarrow Use `pyarrow.Array.to_numpy `_ + for the conversion to NumPy. + zero_copy_only + Raise an exception if the conversion to a NumPy would require copying + the underlying data. Data copy occurs, for example, when the Series contains + nulls or non-numeric types. - for the conversion to numpy. + .. deprecated:: 0.20.10 + Use the `allow_copy` parameter instead, which is the inverse of this + one. Examples -------- @@ -4332,24 +4328,39 @@ def to_numpy( >>> type(arr) """ + if zero_copy_only is not None: + issue_deprecation_warning( + "The `zero_copy_only` parameter for `Series.to_numpy` is deprecated." + " Use the `allow_copy` parameter instead, which is the inverse of `zero_copy_only`.", + version="0.20.10", + ) + allow_copy = not zero_copy_only - def convert_to_date(arr: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: - if self.dtype == Date: - tp = "datetime64[D]" - elif self.dtype == Duration: - tp = f"timedelta64[{self.dtype.time_unit}]" # type: ignore[attr-defined] - else: - tp = f"datetime64[{self.dtype.time_unit}]" # type: ignore[attr-defined] - return arr.astype(tp) - - def raise_no_zero_copy() -> None: - if zero_copy_only: + def raise_on_copy() -> None: + if not allow_copy and not self.is_empty(): msg = "cannot return a zero-copy array" raise ValueError(msg) - if self.dtype == Array: + def temporal_dtype_to_numpy(dtype: PolarsDataType) -> Any: + if dtype == Date: + return np.dtype("datetime64[D]") + elif dtype == Duration: + return np.dtype(f"timedelta64[{dtype.time_unit}]") # type: ignore[union-attr] + elif dtype == Datetime: + return np.dtype(f"datetime64[{dtype.time_unit}]") # type: ignore[union-attr] + else: + msg = f"invalid temporal type: {dtype}" + raise TypeError(msg) + + if self.n_chunks() > 1: + raise_on_copy() + self = self.rechunk() + + dtype = self.dtype + + if dtype == Array: np_array = self.explode().to_numpy( - zero_copy_only=zero_copy_only, + allow_copy=allow_copy, writable=writable, use_pyarrow=use_pyarrow, ) @@ -4359,75 +4370,48 @@ def raise_no_zero_copy() -> None: if ( use_pyarrow and _PYARROW_AVAILABLE - and self.dtype != Object - and (self.dtype == Time or not self.dtype.is_temporal()) + and dtype not in (Object, Datetime, Duration, Date) ): return self.to_arrow().to_numpy( - *args, zero_copy_only=zero_copy_only, writable=writable + zero_copy_only=not allow_copy, writable=writable ) - elif self.dtype in (Time, Decimal): - raise_no_zero_copy() - # note: there are no native numpy "time" or "decimal" dtypes - return np.array(self.to_list(), dtype="object") - else: - if not self.null_count(): - if self.dtype.is_temporal(): - np_array = convert_to_date(self._view(ignore_nulls=True)) - elif self.dtype.is_numeric(): - np_array = self._view(ignore_nulls=True) - else: - raise_no_zero_copy() - np_array = self._s.to_numpy() - - elif self.dtype.is_temporal(): - np_array = convert_to_date(self.to_physical()._s.to_numpy()) + if self.null_count() == 0: + if dtype.is_integer() or dtype.is_float(): + np_array = self._s.to_numpy_view() + elif dtype == Boolean: + raise_on_copy() + s_u8 = self.cast(UInt8) + np_array = s_u8._s.to_numpy_view().view(bool) + elif dtype in (Datetime, Duration): + np_dtype = temporal_dtype_to_numpy(dtype) + s_i64 = self.to_physical() + np_array = s_i64._s.to_numpy_view().view(np_dtype) + elif dtype == Date: + raise_on_copy() + np_dtype = temporal_dtype_to_numpy(dtype) + s_i32 = self.to_physical() + np_array = s_i32._s.to_numpy_view().astype(np_dtype) else: - raise_no_zero_copy() + raise_on_copy() np_array = self._s.to_numpy() - if writable and not np_array.flags.writeable: - raise_no_zero_copy() - return np_array.copy() - else: - return np_array - - def _view(self, *, ignore_nulls: bool = False) -> SeriesView: - """ - Get a view into this Series data with a numpy array. - - This operation doesn't clone data, but does not include missing values. - - Returns - ------- - SeriesView - - Parameters - ---------- - ignore_nulls - If True then nulls are converted to 0. - If False then an Exception is raised if nulls are present. - - Examples - -------- - >>> s = pl.Series("a", [1, None]) - >>> s._view(ignore_nulls=True) - SeriesView([1, 0]) - """ - if not ignore_nulls: - assert not self.null_count() + else: + raise_on_copy() + np_array = self._s.to_numpy() + if dtype in (Datetime, Duration, Date): + np_dtype = temporal_dtype_to_numpy(dtype) + np_array = np_array.view(np_dtype) - from polars.series._numpy import SeriesView, _ptr_to_numpy + if writable and not np_array.flags.writeable: + raise_on_copy() + np_array = np_array.copy() - ptr_type = dtype_to_ctype(self.dtype) - ptr = self._s.as_single_ptr() - array = _ptr_to_numpy(ptr, self.len(), ptr_type) - array.setflags(write=False) - return SeriesView(array, self) + return np_array def to_arrow(self) -> pa.Array: """ - Get the underlying Arrow Array. + Return the underlying Arrow array. If the Series contains only a single chunk this operation is zero copy. @@ -4501,7 +4485,7 @@ def to_pandas( Name: b, dtype: int64[pyarrow] """ if self.dtype == Object: - # Can't convert via PyArrow, so do it via NumPy: + # Can't convert via PyArrow, so do it via NumPy return pd.Series(self.to_numpy(), dtype=object, name=self.name) if use_pyarrow_extension_array: @@ -4517,6 +4501,10 @@ def to_pandas( ) pa_arr = self.to_arrow() + # pandas does not support unsigned dictionary indices + if pa.types.is_dictionary(pa_arr.type): + pa_arr = pa_arr.cast(pa.dictionary(pa.int64(), pa.large_string())) + if use_pyarrow_extension_array: pd_series = pa_arr.to_pandas( self_destruct=True, @@ -5266,8 +5254,9 @@ def map_elements( function Custom function or lambda. return_dtype - Output datatype. If none is given, the same datatype as this Series will be - used. + Output datatype. + If not set, the dtype will be inferred based on the first non-null value + that is returned by the function. skip_nulls Nulls will be skipped and not passed to the python function. This is faster because python can be skipped and because we call @@ -6473,6 +6462,16 @@ def kurtosis(self, *, fisher: bool = True, bias: bool = True) -> float | None: Pearson's definition is used (normal ==> 3.0). bias : bool, optional If False, the calculations are corrected for statistical bias. + + Examples + -------- + >>> s = pl.Series("grades", [66, 79, 54, 97, 96, 70, 69, 85, 93, 75]) + >>> s.kurtosis() + -1.0522623626787952 + >>> s.kurtosis(fisher=False) + 1.9477376373212048 + >>> s.kurtosis(fisher=False, bias=False) + 2.104036180264273 """ return self._s.kurtosis(fisher, bias) @@ -6781,7 +6780,7 @@ def ewm_mean( *, adjust: bool = True, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Series: r""" Exponentially-weighted moving average. @@ -6810,7 +6809,7 @@ def ewm_mean( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -6824,7 +6823,7 @@ def ewm_mean( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -6832,7 +6831,7 @@ def ewm_mean( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -6842,7 +6841,7 @@ def ewm_mean( Examples -------- >>> s = pl.Series([1, 2, 3]) - >>> s.ewm_mean(com=1) + >>> s.ewm_mean(com=1, ignore_nulls=False) shape: (3,) Series: '' [f64] [ @@ -6863,7 +6862,7 @@ def ewm_std( adjust: bool = True, bias: bool = False, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Series: r""" Exponentially-weighted moving standard deviation. @@ -6892,7 +6891,7 @@ def ewm_std( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -6909,7 +6908,7 @@ def ewm_std( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -6917,7 +6916,7 @@ def ewm_std( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -6927,7 +6926,7 @@ def ewm_std( Examples -------- >>> s = pl.Series("a", [1, 2, 3]) - >>> s.ewm_std(com=1) + >>> s.ewm_std(com=1, ignore_nulls=False) shape: (3,) Series: 'a' [f64] [ @@ -6948,7 +6947,7 @@ def ewm_var( adjust: bool = True, bias: bool = False, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Series: r""" Exponentially-weighted moving variance. @@ -6977,7 +6976,7 @@ def ewm_var( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -6994,7 +6993,7 @@ def ewm_var( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -7002,7 +7001,7 @@ def ewm_var( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -7012,7 +7011,7 @@ def ewm_var( Examples -------- >>> s = pl.Series("a", [1, 2, 3]) - >>> s.ewm_var(com=1) + >>> s.ewm_var(com=1, ignore_nulls=False) shape: (3,) Series: 'a' [f64] [ @@ -7022,15 +7021,15 @@ def ewm_var( ] """ - def extend_constant(self, value: PythonLiteral | None, n: int) -> Series: + def extend_constant(self, value: IntoExpr, n: int | IntoExprColumn) -> Series: """ Extremely fast method for extending the Series with 'n' copies of a value. Parameters ---------- value - A constant literal value (not an expression) with which to extend - the Series; can pass None to extend with nulls. + A constant literal value or a unit expressioin with which to extend the + expression result Series; can pass None to extend with nulls. n The number of additional values that will be added. @@ -7518,7 +7517,7 @@ def cumprod(self, *, reverse: bool = False) -> Series: return self.cum_prod(reverse=reverse) @deprecate_function( - "Use `Series.to_numpy(zero_copy_only=True) instead.", version="0.19.14" + "Use `Series.to_numpy(allow_copy=False) instead.", version="0.19.14" ) def view(self, *, ignore_nulls: bool = False) -> SeriesView: """ @@ -7536,7 +7535,16 @@ def view(self, *, ignore_nulls: bool = False) -> SeriesView: If True then nulls are converted to 0. If False then an Exception is raised if nulls are present. """ - return self._view(ignore_nulls=ignore_nulls) + if not ignore_nulls: + assert not self.null_count() + + from polars.series._numpy import SeriesView, _ptr_to_numpy + + ptr_type = dtype_to_ctype(self.dtype) + ptr = self._s.as_single_ptr() + array = _ptr_to_numpy(ptr, self.len(), ptr_type) + array.setflags(write=False) + return SeriesView(array, self) @deprecate_function( "It has been renamed to `replace`." diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index af4907c9a7d8..05e385dc5fad 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from polars.datatypes.constants import N_INFER_DEFAULT from polars.series.utils import expr_dispatch from polars.utils.deprecation import ( deprecate_renamed_function, @@ -371,6 +372,12 @@ def len_chars(self) -> Series: equivalent output with much better performance: :func:`len_bytes` runs in _O(1)_, while :func:`len_chars` runs in (_O(n)_). + A character is defined as a `Unicode scalar value`_. A single character is + represented by a single byte when working with ASCII text, and a maximum of + 4 bytes otherwise. + + .. _Unicode scalar value: https://www.unicode.org/glossary/#unicode_scalar_value + Examples -------- >>> s = pl.Series(["Café", "345", "東京", None]) @@ -628,8 +635,8 @@ def starts_with(self, prefix: str | Expr) -> Series: """ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Series: - """ - Decode a value using the provided encoding. + r""" + Decode values using the provided encoding. Parameters ---------- @@ -638,6 +645,23 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Series: strict Raise an error if the underlying value cannot be decoded, otherwise mask out with a null value. + + Returns + ------- + Series + Series of data type :class:`Binary`. + + Examples + -------- + >>> s = pl.Series("color", ["000000", "ffff00", "0000ff"]) + >>> s.str.decode("hex") + shape: (3,) + Series: 'color' [binary] + [ + b"\x00\x00\x00" + b"\xff\xff\x00" + b"\x00\x00\xff" + ] """ def encode(self, encoding: TransferEncoding) -> Series: @@ -668,7 +692,9 @@ def encode(self, encoding: TransferEncoding) -> Series: """ def json_decode( - self, dtype: PolarsDataType | None = None, infer_schema_length: int | None = 100 + self, + dtype: PolarsDataType | None = None, + infer_schema_length: int | None = N_INFER_DEFAULT, ) -> Series: """ Parse string values as JSON. @@ -681,8 +707,8 @@ def json_decode( The dtype to cast the extracted value to. If None, the dtype will be inferred from the JSON value. infer_schema_length - How many rows to parse to determine the schema. - If `None` all rows are used. + The maximum number of rows to scan for schema inference. + If set to `None`, the full data may be scanned *(this is slow)*. See Also -------- @@ -1124,56 +1150,71 @@ def replace( value String that will replace the matched substring. literal - Treat pattern as a literal string. + Treat `pattern` as a literal string. n Number of matches to replace. + See Also + -------- + replace_all + Notes ----- - To modify regular expression behaviour (such as case-sensitivity) with flags, - use the inline `(?iLmsuxU)` syntax. For example: + The dollar sign (`$`) is a special character related to capture groups. + To refer to a literal dollar sign, use `$$` instead or set `literal` to `True`. - >>> s = pl.Series( - ... name="weather", - ... values=[ - ... "Foggy", - ... "Rainy", - ... "Sunny", - ... ], - ... ) - >>> # apply case-insensitive string replacement - >>> s.str.replace(r"(?i)foggy|rainy", "Sunny") - shape: (3,) - Series: 'weather' [str] - [ - "Sunny" - "Sunny" - "Sunny" - ] - - See the regex crate's section on `grouping and flags - `_ for - additional information about the use of inline expression modifiers. - - See Also - -------- - replace_all : Replace all matching regex/literal substrings. + To modify regular expression behaviour (such as case-sensitivity) with flags, + use the inline `(?iLmsuxU)` syntax. See the regex crate's section on + `grouping and flags `_ + for additional information about the use of inline expression modifiers. Examples -------- >>> s = pl.Series(["123abc", "abc456"]) - >>> s.str.replace(r"abc\b", "ABC") # doctest: +IGNORE_RESULT + >>> s.str.replace(r"abc\b", "ABC") shape: (2,) Series: '' [str] [ "123ABC" "abc456" ] + + Capture groups are supported. Use `${1}` in the `value` string to refer to the + first capture group in the `pattern`, `${2}` to refer to the second capture + group, and so on. You can also use named capture groups. + + >>> s = pl.Series(["hat", "hut"]) + >>> s.str.replace("h(.)t", "b${1}d") + shape: (2,) + Series: '' [str] + [ + "bad" + "bud" + ] + >>> s.str.replace("h(?.)t", "b${vowel}d") + shape: (2,) + Series: '' [str] + [ + "bad" + "bud" + ] + + Apply case-insensitive string replacement using the `(?i)` flag. + + >>> s = pl.Series("weather", ["Foggy", "Rainy", "Sunny"]) + >>> s.str.replace(r"(?i)foggy|rainy", "Sunny") + shape: (3,) + Series: 'weather' [str] + [ + "Sunny" + "Sunny" + "Sunny" + ] """ def replace_all(self, pattern: str, value: str, *, literal: bool = False) -> Series: - """ - Replace all matching regex/literal substrings with a new string value. + r""" + Replace first matching regex/literal substring with a new string value. Parameters ---------- @@ -1181,23 +1222,67 @@ def replace_all(self, pattern: str, value: str, *, literal: bool = False) -> Ser A valid regular expression pattern, compatible with the `regex crate `_. value - String that will replace the matches. + String that will replace the matched substring. literal - Treat pattern as a literal string. + Treat `pattern` as a literal string. + n + Number of matches to replace. See Also -------- - replace : Replace first matching regex/literal substring. + replace_all + + Notes + ----- + The dollar sign (`$`) is a special character related to capture groups. + To refer to a literal dollar sign, use `$$` instead or set `literal` to `True`. + + To modify regular expression behaviour (such as case-sensitivity) with flags, + use the inline `(?iLmsuxU)` syntax. See the regex crate's section on + `grouping and flags `_ + for additional information about the use of inline expression modifiers. Examples -------- - >>> df = pl.Series(["abcabc", "123a123"]) - >>> df.str.replace_all("a", "-") + >>> s = pl.Series(["123abc", "abc456"]) + >>> s.str.replace_all(r"abc\b", "ABC") shape: (2,) Series: '' [str] [ - "-bc-bc" - "123-123" + "123ABC" + "abc456" + ] + + Capture groups are supported. Use `${1}` in the `value` string to refer to the + first capture group in the `pattern`, `${2}` to refer to the second capture + group, and so on. You can also use named capture groups. + + >>> s = pl.Series(["hat", "hut"]) + >>> s.str.replace_all("h(.)t", "b${1}d") + shape: (2,) + Series: '' [str] + [ + "bad" + "bud" + ] + >>> s.str.replace_all("h(?.)t", "b${vowel}d") + shape: (2,) + Series: '' [str] + [ + "bad" + "bud" + ] + + Apply case-insensitive string replacement using the `(?i)` flag. + + >>> s = pl.Series("weather", ["Foggy", "Rainy", "Sunny"]) + >>> s.str.replace_all(r"(?i)foggy|rainy", "Sunny") + shape: (3,) + Series: 'weather' [str] + [ + "Sunny" + "Sunny" + "Sunny" ] """ @@ -1521,7 +1606,7 @@ def slice( self, offset: int | IntoExprColumn, length: int | IntoExprColumn | None = None ) -> Series: """ - Create subslices of the string values of a String Series. + Extract a substring from each string value. Parameters ---------- @@ -1534,15 +1619,23 @@ def slice( Returns ------- Series - Series of data type :class:`Struct` with fields of data type - :class:`String`. + Series of data type :class:`String`. + + Notes + ----- + Both the `offset` and `length` inputs are defined in terms of the number + of characters in the (UTF8) string. A character is defined as a + `Unicode scalar value`_. A single character is represented by a single byte + when working with ASCII text, and a maximum of 4 bytes otherwise. + + .. _Unicode scalar value: https://www.unicode.org/glossary/#unicode_scalar_value Examples -------- - >>> s = pl.Series("s", ["pear", None, "papaya", "dragonfruit"]) + >>> s = pl.Series(["pear", None, "papaya", "dragonfruit"]) >>> s.str.slice(-3) shape: (4,) - Series: 's' [str] + Series: '' [str] [ "ear" null @@ -1554,7 +1647,7 @@ def slice( >>> s.str.slice(4, length=3) shape: (4,) - Series: 's' [str] + Series: '' [str] [ "" null @@ -1773,7 +1866,9 @@ def rjust(self, length: int, fill_char: str = " ") -> Series: @deprecate_renamed_function("json_decode", version="0.19.15") def json_extract( - self, dtype: PolarsDataType | None = None, infer_schema_length: int | None = 100 + self, + dtype: PolarsDataType | None = None, + infer_schema_length: int | None = N_INFER_DEFAULT, ) -> Series: """ Parse string values as JSON. @@ -1787,8 +1882,8 @@ def json_extract( The dtype to cast the extracted value to. If None, the dtype will be inferred from the JSON value. infer_schema_length - How many rows to parse to determine the schema. - If `None` all rows are used. + The maximum number of rows to scan for schema inference. + If set to `None`, the full data may be scanned *(this is slow)*. """ return self.json_decode(dtype, infer_schema_length) diff --git a/py-polars/polars/series/struct.py b/py-polars/polars/series/struct.py index dbd3faaf9bd0..b0fe9f4e22b9 100644 --- a/py-polars/polars/series/struct.py +++ b/py-polars/polars/series/struct.py @@ -1,17 +1,16 @@ from __future__ import annotations -import os from collections import OrderedDict from typing import TYPE_CHECKING, Sequence from polars.series.utils import expr_dispatch from polars.utils._wrap import wrap_df -from polars.utils.various import sphinx_accessor +from polars.utils.various import BUILDING_SPHINX_DOCS, sphinx_accessor if TYPE_CHECKING: from polars import DataFrame, DataType, Series from polars.polars import PySeries -elif os.getenv("BUILDING_SPHINX_DOCS"): +elif BUILDING_SPHINX_DOCS: property = sphinx_accessor diff --git a/py-polars/polars/testing/asserts/frame.py b/py-polars/polars/testing/asserts/frame.py index 40e52fe0df55..ff2f8fc04c39 100644 --- a/py-polars/polars/testing/asserts/frame.py +++ b/py-polars/polars/testing/asserts/frame.py @@ -85,6 +85,8 @@ def assert_frame_equal( ... AssertionError: values for column 'a' are different """ + __tracebackhide__ = True + lazy = _assert_correct_input_type(left, right) objects = "LazyFrames" if lazy else "DataFrames" @@ -132,6 +134,8 @@ def assert_frame_equal( def _assert_correct_input_type( left: DataFrame | LazyFrame, right: DataFrame | LazyFrame ) -> bool: + __tracebackhide__ = True + if isinstance(left, DataFrame) and isinstance(right, DataFrame): return False elif isinstance(left, LazyFrame) and isinstance(right, LazyFrame): @@ -153,6 +157,8 @@ def _assert_frame_schema_equal( check_column_order: bool, objects: str, ) -> None: + __tracebackhide__ = True + left_schema, right_schema = left.schema, right.schema # Fast path for equal frames @@ -253,6 +259,8 @@ def assert_frame_not_equal( ... AssertionError: frames are equal """ + __tracebackhide__ = True + try: assert_frame_equal( left=left, diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index 4a4b5b16d5a1..5bf691037ea9 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -6,8 +6,6 @@ FLOAT_DTYPES, Array, Categorical, - Decimal, - Float64, List, String, Struct, @@ -85,6 +83,8 @@ def assert_series_equal( [left]: [1, 2, 3] [right]: [1, 5, 3] """ + __tracebackhide__ = True + if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr] raise_assertion_error( "inputs", @@ -121,6 +121,8 @@ def _assert_series_values_equal( atol: float, categorical_as_str: bool, ) -> None: + __tracebackhide__ = True + """Assert that the values in both Series are equal.""" # Handle categoricals if categorical_as_str: @@ -129,14 +131,6 @@ def _assert_series_values_equal( if right.dtype == Categorical: right = right.cast(String) - # Handle decimals - # TODO: Delete this branch when Decimal equality is implemented - # https://github.com/pola-rs/polars/issues/12118 - if left.dtype == Decimal: - left = left.cast(Float64) - if right.dtype == Decimal: - right = right.cast(Float64) - # Determine unequal elements try: unequal = left.ne_missing(right) @@ -201,6 +195,8 @@ def _assert_series_nested_values_equal( atol: float, categorical_as_str: bool, ) -> None: + __tracebackhide__ = True + # compare nested lists element-wise if _comparing_lists(left.dtype, right.dtype): for s1, s2 in zip(left, right): @@ -231,6 +227,7 @@ def _assert_series_nested_values_equal( def _assert_series_null_values_match(left: Series, right: Series) -> None: + __tracebackhide__ = True null_value_mismatch = left.is_null() != right.is_null() if null_value_mismatch.any(): raise_assertion_error( @@ -239,6 +236,7 @@ def _assert_series_null_values_match(left: Series, right: Series) -> None: def _assert_series_nan_values_match(left: Series, right: Series) -> None: + __tracebackhide__ = True if not _comparing_floats(left.dtype, right.dtype): return nan_value_mismatch = left.is_nan() != right.is_nan() @@ -280,6 +278,8 @@ def _assert_series_values_within_tolerance( rtol: float, atol: float, ) -> None: + __tracebackhide__ = True + left_unequal, right_unequal = left.filter(unequal), right.filter(unequal) difference = (left_unequal - right_unequal).abs() @@ -349,6 +349,8 @@ def assert_series_not_equal( ... AssertionError: Series are equal """ + __tracebackhide__ = True + try: assert_series_equal( left=left, diff --git a/py-polars/polars/testing/parametric/__init__.py b/py-polars/polars/testing/parametric/__init__.py index 892cb27be2db..862b0b0d923a 100644 --- a/py-polars/polars/testing/parametric/__init__.py +++ b/py-polars/polars/testing/parametric/__init__.py @@ -7,6 +7,7 @@ from polars.testing.parametric.profiles import load_profile, set_profile from polars.testing.parametric.strategies import ( all_strategies, + create_array_strategy, create_list_strategy, nested_strategies, scalar_strategies, @@ -22,6 +23,7 @@ def __getattr__(*args: Any, **kwargs: Any) -> Any: "all_strategies", "column", "columns", + "create_array_strategy", "create_list_strategy", "dataframes", "load_profile", diff --git a/py-polars/polars/testing/parametric/primitives.py b/py-polars/polars/testing/parametric/primitives.py index b87f1541c52f..2705723965c7 100644 --- a/py-polars/polars/testing/parametric/primitives.py +++ b/py-polars/polars/testing/parametric/primitives.py @@ -14,6 +14,7 @@ from polars.dataframe import DataFrame from polars.datatypes import ( DTYPE_TEMPORAL_UNITS, + Array, Categorical, DataType, DataTypeClass, @@ -29,6 +30,7 @@ _flexhash, all_strategies, between, + create_array_strategy, create_list_strategy, scalar_strategies, ) @@ -41,7 +43,6 @@ from polars import LazyFrame from polars.type_aliases import OneOrMoreDataTypes, PolarsDataType - _time_units = list(DTYPE_TEMPORAL_UNITS) @@ -122,11 +123,19 @@ def __post_init__(self) -> None: if self.dtype is None and self.strategy is None: self.dtype = random.choice(strategy_dtypes) - elif self.dtype == List: + elif self.dtype in (Array, List): if self.strategy is not None: self.dtype = getattr(self.strategy, "_dtype", self.dtype) else: - self.strategy = create_list_strategy(getattr(self.dtype, "inner", None)) + if self.dtype == Array: + self.strategy = create_array_strategy( + getattr(self.dtype, "inner", None), + getattr(self.dtype, "width", None), + ) + else: + self.strategy = create_list_strategy( + getattr(self.dtype, "inner", None) + ) self.dtype = self.strategy._dtype # type: ignore[attr-defined] # elif self.dtype == Struct: @@ -162,7 +171,7 @@ def __post_init__(self) -> None: if sample_value_type is not None: value_dtype = py_type_to_dtype(sample_value_type) - if value_dtype is not List: + if value_dtype is not Array and value_dtype is not List: self.dtype = value_dtype @@ -380,7 +389,7 @@ def draw_series(draw: DrawFn) -> Series: else: dtype_strategy = strategy - if series_dtype.is_float() and not allow_infinities: + if not allow_infinities and series_dtype.is_float(): dtype_strategy = dtype_strategy.filter( lambda x: not isinstance(x, float) or isfinite(x) ) @@ -676,7 +685,7 @@ def draw_frames(draw: DrawFn) -> DataFrame | LazyFrame: # note: randomly change between column-wise and row-wise frame init orient = "col" - if draw(booleans()) and not any(c.dtype == List for c in coldefs): + if draw(booleans()) and not any(c.dtype in (Array, List) for c in coldefs): data = list(zip(*data.values())) # type: ignore[assignment] orient = "row" @@ -707,7 +716,7 @@ def draw_frames(draw: DrawFn) -> DataFrame | LazyFrame: # failed frame init: reproduce with... pl.DataFrame( data={frame_data}, - schema={repr(schema).replace("', ","', pl.")}, + schema={repr(schema).replace("', ", "', pl.")}, orient={orient!r}, ) """.replace("datetime.", "") diff --git a/py-polars/polars/testing/parametric/strategies.py b/py-polars/polars/testing/parametric/strategies.py index 74e26a11ebce..c014a004e66f 100644 --- a/py-polars/polars/testing/parametric/strategies.py +++ b/py-polars/polars/testing/parametric/strategies.py @@ -3,7 +3,7 @@ import os from datetime import datetime, timedelta from itertools import chain -from random import choice, shuffle +from random import choice, randint, shuffle from string import ascii_uppercase from typing import ( TYPE_CHECKING, @@ -36,6 +36,7 @@ ) from polars.datatypes import ( + Array, Binary, Boolean, Categorical, @@ -59,7 +60,6 @@ is_polars_dtype, ) from polars.type_aliases import PolarsDataType -from polars.utils.deprecation import deprecate_nonkeyword_arguments if TYPE_CHECKING: import sys @@ -315,14 +315,68 @@ def _flexhash(elem: Any) -> int: return hash(elem) -@deprecate_nonkeyword_arguments(allowed_args=["inner_dtype"], version="0.19.3") +def create_array_strategy( + inner_dtype: PolarsDataType | None = None, + width: int | None = None, + *, + select_from: Sequence[Any] | None = None, + unique: bool = False, +) -> SearchStrategy[list[Any]]: + """ + Hypothesis strategy for producing polars Array data. + + Parameters + ---------- + inner_dtype : PolarsDataType + type of the inner array elements (can also be another Array). + width : int, optional + generated arrays will have this length. + select_from : list, optional + randomly select the innermost values from this list (otherwise + the default strategy associated with the innermost dtype is used). + unique : bool, optional + ensure that the generated lists contain unique values. + + Examples + -------- + Create a strategy that generates arrays of i32 values: + + >>> arr = create_array_strategy(inner_dtype=pl.Int32, width=3) + >>> arr.example() # doctest: +SKIP + [-11330, 24030, 116] + + Create a strategy that generates arrays of specific strings: + + >>> arr = create_array_strategy(inner_dtype=pl.String, width=2) + >>> arr.example() # doctest: +SKIP + ['xx', 'yy'] + """ + if width is None: + width = randint(a=1, b=8) + + if inner_dtype is None: + strats = list(_get_strategy_dtypes(base_type=True)) + shuffle(strats) + inner_dtype = choice(strats) + + strat = create_list_strategy( + inner_dtype=inner_dtype, + select_from=select_from, + size=width, + unique=unique, + ) + strat._dtype = Array(inner_dtype, width=width) # type: ignore[attr-defined] + return strat + + def create_list_strategy( - inner_dtype: PolarsDataType | None, + inner_dtype: PolarsDataType | None = None, + *, select_from: Sequence[Any] | None = None, size: int | None = None, min_size: int | None = None, max_size: int | None = None, - unique: bool = False, # noqa: FBT001 + unique: bool = False, ) -> SearchStrategy[list[Any]]: """ Hypothesis strategy for producing polars List data. @@ -391,13 +445,23 @@ def create_list_strategy( if max_size is None: max_size = 3 if not min_size else (min_size * 2) - if inner_dtype == List: - st = create_list_strategy( - inner_dtype=inner_dtype.inner, # type: ignore[union-attr] - select_from=select_from, - min_size=min_size, - max_size=max_size, - ) + if inner_dtype in (Array, List): + if inner_dtype == Array: + if (width := getattr(inner_dtype, "width", None)) is None: + width = randint(a=1, b=8) + st = create_array_strategy( + inner_dtype=inner_dtype.inner, # type: ignore[union-attr] + select_from=select_from, + width=width, + ) + else: + st = create_list_strategy( + inner_dtype=inner_dtype.inner, # type: ignore[union-attr] + select_from=select_from, + min_size=min_size, + max_size=max_size, + ) + if inner_dtype.inner is None and hasattr(st, "_dtype"): # type: ignore[union-attr] inner_dtype = st._dtype else: @@ -421,6 +485,7 @@ def create_list_strategy( # def create_struct_strategy( +nested_strategies[Array] = create_array_strategy nested_strategies[List] = create_list_strategy # nested_strategies[Struct] = create_struct_strategy(inner_dtype=None) diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index eee5e670d8a6..ea1a8a0cf6c7 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -23,6 +23,7 @@ import sys from sqlalchemy import Engine + from sqlalchemy.orm import Session from polars import DataFrame, Expr, LazyFrame, Series from polars.datatypes import DataType, DataTypeClass, IntegerType, TemporalType @@ -250,4 +251,4 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any: """Fetch results in batches.""" -ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine"] +ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine", "Session"] diff --git a/py-polars/polars/utils/__init__.py b/py-polars/polars/utils/__init__.py index 3756418259df..133bca13981b 100644 --- a/py-polars/polars/utils/__init__.py +++ b/py-polars/polars/utils/__init__.py @@ -4,11 +4,10 @@ Functions that are part of the public API are re-exported here. """ from polars.utils._scan import _execute_from_rust -from polars.utils.build_info import build_info from polars.utils.convert import ( _date_to_pl_date, - _datetime_for_anyvalue, - _datetime_for_anyvalue_windows, + _datetime_for_any_value, + _datetime_for_any_value_windows, _time_to_pl_time, _timedelta_to_pl_timedelta, _to_python_date, @@ -17,22 +16,16 @@ _to_python_time, _to_python_timedelta, ) -from polars.utils.meta import get_index_type, threadpool_size -from polars.utils.show_versions import show_versions from polars.utils.various import NoDefault, _polars_warn, is_column, no_default __all__ = [ "NoDefault", - "build_info", - "get_index_type", "is_column", "no_default", - "show_versions", - "threadpool_size", # Required for Rust bindings "_date_to_pl_date", - "_datetime_for_anyvalue", - "_datetime_for_anyvalue_windows", + "_datetime_for_any_value", + "_datetime_for_any_value_windows", "_execute_from_rust", "_polars_warn", "_time_to_pl_time", diff --git a/py-polars/polars/utils/_construction.py b/py-polars/polars/utils/_construction.py index 0501dfb6e5d2..9f3561ea91a7 100644 --- a/py-polars/polars/utils/_construction.py +++ b/py-polars/polars/utils/_construction.py @@ -63,9 +63,14 @@ from polars.dependencies import numpy as np from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa -from polars.exceptions import ComputeError, ShapeError, TimeZoneAwareConstructorWarning +from polars.exceptions import ( + ComputeError, + SchemaError, + ShapeError, + TimeZoneAwareConstructorWarning, +) +from polars.meta import get_index_type, thread_pool_size from polars.utils._wrap import wrap_df, wrap_s -from polars.utils.meta import get_index_type, threadpool_size from polars.utils.various import ( _is_generator, arrlen, @@ -226,7 +231,7 @@ def arrow_to_pyseries(name: str, values: pa.Array, *, rechunk: bool = True) -> P else: if array.num_chunks > 1: # somehow going through ffi with a structarray - # returns the first chunk everytime + # returns the first chunk every time if isinstance(array.type, pa.StructType): pys = PySeries.from_arrow(name, array.combine_chunks()) else: @@ -235,7 +240,7 @@ def arrow_to_pyseries(name: str, values: pa.Array, *, rechunk: bool = True) -> P for a in it: pys.append(PySeries.from_arrow(name, a)) elif array.num_chunks == 0: - pys = PySeries.from_arrow(name, pa.array([], array.type)) + pys = PySeries.from_arrow(name, pa.nulls(0, type=array.type)) else: pys = PySeries.from_arrow(name, array.chunks[0]) @@ -292,24 +297,27 @@ def _get_first_non_none(values: Sequence[Any | None]) -> Any: return next((v for v in values if v is not None), None) -def sequence_from_anyvalue_or_object(name: str, values: Sequence[Any]) -> PySeries: +def sequence_from_any_value_or_object(name: str, values: Sequence[Any]) -> PySeries: """ Last resort conversion. AnyValues are most flexible and if they fail we go for object types """ try: - return PySeries.new_from_anyvalues(name, values, strict=True) + return PySeries.new_from_any_values(name, values, strict=True) # raised if we cannot convert to Wrap except RuntimeError: return PySeries.new_object(name, values, _strict=False) + # raised if AnyValue fallbacks fail + except SchemaError: + return PySeries.new_object(name, values, _strict=False) except ComputeError as exc: if "mixed dtypes" in str(exc): return PySeries.new_object(name, values, _strict=False) raise -def sequence_from_anyvalue_and_dtype_or_object( +def sequence_from_any_value_and_dtype_or_object( name: str, values: Sequence[Any], dtype: PolarsDataType ) -> PySeries: """ @@ -318,7 +326,7 @@ def sequence_from_anyvalue_and_dtype_or_object( AnyValues are most flexible and if they fail we go for object types """ try: - return PySeries.new_from_anyvalues_and_dtype(name, values, dtype, strict=True) + return PySeries.new_from_any_values_and_dtype(name, values, dtype, strict=True) # raised if we cannot convert to Wrap except RuntimeError: return PySeries.new_object(name, values, _strict=False) @@ -449,7 +457,7 @@ def sequence_to_pyseries( dataclasses.is_dataclass(value) or is_pydantic_model(value) or is_namedtuple(value.__class__) - ): + ) and dtype != Object: return pl.DataFrame(values).to_struct(name)._s elif isinstance(value, range): values = [range_to_series("", v) for v in values] @@ -520,18 +528,24 @@ def sequence_to_pyseries( msg ) - # we use anyvalue builder to create the datetime array - # we store the values internally as UTC and set the timezone - py_series = PySeries.new_from_anyvalues(name, values, strict) + # We use the AnyValue builder to create the datetime array + # We store the values internally as UTC and set the timezone + py_series = PySeries.new_from_any_values(name, values, strict) + time_unit = getattr(dtype, "time_unit", None) + time_zone = getattr(dtype, "time_zone", None) + if time_unit is None or values_dtype == Date: s = wrap_s(py_series) else: s = wrap_s(py_series).dt.cast_time_unit(time_unit) - time_zone = getattr(dtype, "time_zone", None) if (values_dtype == Date) & (dtype == Datetime): - return s.cast(Datetime(time_unit)).dt.replace_time_zone(time_zone)._s + return ( + s.cast(Datetime(time_unit or "us")) + .dt.replace_time_zone(time_zone) + ._s + ) if (dtype == Datetime) and ( value.tzinfo is not None or time_zone is not None @@ -588,11 +602,11 @@ def sequence_to_pyseries( if isinstance(dtype, Object): return PySeries.new_object(name, values, strict) if dtype: - srs = sequence_from_anyvalue_and_dtype_or_object(name, values, dtype) - if not dtype.is_(srs.dtype()): + srs = sequence_from_any_value_and_dtype_or_object(name, values, dtype) + if dtype != srs.dtype(): srs = srs.cast(dtype, strict=False) return srs - return sequence_from_anyvalue_or_object(name, values) + return sequence_from_any_value_or_object(name, values) elif python_dtype == pl.Series: return PySeries.new_series_list(name, [v._s for v in values], strict) @@ -603,7 +617,7 @@ def sequence_to_pyseries( constructor = py_type_to_constructor(python_dtype) if constructor == PySeries.new_object: try: - srs = PySeries.new_from_anyvalues(name, values, strict) + srs = PySeries.new_from_any_values(name, values, strict) if _check_for_numpy(python_dtype, check_type=False) and isinstance( np.bool_(True), np.generic ): @@ -614,7 +628,7 @@ def sequence_to_pyseries( except RuntimeError: # raised if we cannot convert to Wrap - return sequence_from_anyvalue_or_object(name, values) + return sequence_from_any_value_or_object(name, values) return _construct_series_with_fallbacks( constructor, name, values, dtype, strict=strict @@ -665,7 +679,10 @@ def _pandas_series_to_arrow( def pandas_to_pyseries( - name: str, values: pd.Series[Any] | pd.DatetimeIndex, *, nan_to_null: bool = True + name: str, + values: pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex, + *, + nan_to_null: bool = True, ) -> PySeries: """Construct a PySeries from a pandas Series or DatetimeIndex.""" # TODO: Change `if not name` to `if name is not None` once name is Optional[str] @@ -756,7 +773,6 @@ def _unpack_schema( schema_overrides: SchemaDict | None = None, n_expected: int | None = None, lookup_names: Iterable[str] | None = None, - include_overrides_in_columns: bool = False, ) -> tuple[list[str], SchemaDict]: """ Unpack column names and create dtype lookup. @@ -764,23 +780,33 @@ def _unpack_schema( Works for any (name, dtype) pairs or schema dict input, overriding any inferred dtypes with explicit dtypes if supplied. """ - # coerce schema_overrides to dict[str, PolarsDataType] - if schema_overrides: - schema_overrides = { - name: dtype - if is_polars_dtype(dtype, include_unknown=True) - else py_type_to_dtype(dtype) - for name, dtype in schema_overrides.items() + + def _normalize_dtype(dtype: Any) -> PolarsDataType: + """Parse non-Polars data types as Polars data types.""" + if is_polars_dtype(dtype, include_unknown=True): + return dtype + else: + return py_type_to_dtype(dtype) + + def _parse_schema_overrides( + schema_overrides: SchemaDict | None = None, + ) -> dict[str, PolarsDataType]: + """Parse schema overrides as a dictionary of name to Polars data type.""" + if schema_overrides is None: + return {} + + return { + name: _normalize_dtype(dtype) for name, dtype in schema_overrides.items() } - else: - schema_overrides = {} - # fastpath for empty schema + schema_overrides = _parse_schema_overrides(schema_overrides) + + # Fast path for empty schema if not schema: - return ( - [f"column_{i}" for i in range(n_expected)] if n_expected else [], - schema_overrides, + columns = ( + [f"column_{i}" for i in range(n_expected)] if n_expected is not None else [] ) + return columns, schema_overrides # determine column names from schema if isinstance(schema, Mapping): @@ -795,25 +821,32 @@ def _unpack_schema( # determine column dtypes from schema and lookup_names lookup: dict[str, str] | None = ( - {col: name for col, name in zip_longest(column_names, lookup_names) if name} + { + col: name + for col, name in zip_longest(column_names, lookup_names) + if name is not None + } if lookup_names else None ) - column_dtypes: dict[str, PolarsDataType] = { - lookup.get((name := col[0]), name) if lookup else col[0]: dtype # type: ignore[misc] - if is_polars_dtype(dtype, include_unknown=True) - else py_type_to_dtype(dtype) - for col in schema - if isinstance(col, tuple) and (dtype := col[1]) is not None - } + + column_dtypes: dict[str, PolarsDataType] = {} + for col in schema: + if isinstance(col, str): + continue + + name, dtype = col + if dtype is None: + continue + else: + dtype = _normalize_dtype(dtype) + name = lookup.get(name, name) if lookup else name + column_dtypes[name] = dtype # type: ignore[assignment] # apply schema overrides if schema_overrides: column_dtypes.update(schema_overrides) - if include_overrides_in_columns: - column_names.extend(col for col in column_dtypes if col not in column_names) - return column_names, column_dtypes @@ -875,9 +908,9 @@ def _expand_dict_scalars( elif val is None or isinstance( # type: ignore[redundant-expr] val, (int, float, str, bool, date, datetime, time, timedelta) ): - updated_data[name] = pl.Series( - name=name, values=[val], dtype=dtype - ).extend_constant(val, array_len - 1) + updated_data[name] = F.repeat( + val, array_len, dtype=dtype, eager=True + ).alias(name) else: updated_data[name] = pl.Series( name=name, values=[val] * array_len, dtype=dtype @@ -936,7 +969,7 @@ def dict_to_pydf( # (note: 'dummy' is threaded) import multiprocessing.dummy - pool_size = threadpool_size() + pool_size = thread_pool_size() with multiprocessing.dummy.Pool(pool_size) as pool: data = dict( zip( @@ -1060,7 +1093,7 @@ def _sequence_to_pydf_dispatcher( to_pydf = _sequence_of_numpy_to_pydf elif _check_for_pandas(first_element) and isinstance( - first_element, (pd.Series, pd.DatetimeIndex) + first_element, (pd.Series, pd.Index, pd.DatetimeIndex) ): to_pydf = _sequence_of_pandas_to_pydf @@ -1253,7 +1286,7 @@ def _sequence_of_numpy_to_pydf( def _sequence_of_pandas_to_pydf( - first_element: pd.Series[Any] | pd.DatetimeIndex, + first_element: pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex, data: Sequence[Any], schema: SchemaDefinition | None, schema_overrides: SchemaDict | None, @@ -1541,19 +1574,21 @@ def numpy_to_pydf( def arrow_to_pydf( - data: pa.Table, + data: pa.Table | pa.RecordBatch, schema: SchemaDefinition | None = None, *, schema_overrides: SchemaDict | None = None, rechunk: bool = True, ) -> PyDataFrame: - """Construct a PyDataFrame from an Arrow Table.""" + """Construct a PyDataFrame from an Arrow Table or RecordBatch.""" original_schema = schema column_names, schema_overrides = _unpack_schema( (schema or data.column_names), schema_overrides=schema_overrides ) try: if column_names != data.column_names: + if isinstance(data, pa.RecordBatch): + data = pa.Table.from_batches([data]) data = data.rename_columns(column_names) except pa.lib.ArrowInvalid as e: msg = "dimensions of columns arg must match data dimensions" @@ -1809,10 +1844,10 @@ def pandas_to_pydf( ) -def coerce_arrow(array: pa.Array, *, rechunk: bool = True) -> pa.Array: +def coerce_arrow(array: pa.Array) -> pa.Array: import pyarrow.compute as pc - if hasattr(array, "num_chunks") and array.num_chunks > 1 and rechunk: + if hasattr(array, "num_chunks") and array.num_chunks > 1: # small integer keys can often not be combined, so let's already cast # to the uint32 used by polars if pa.types.is_dictionary(array.type) and ( diff --git a/py-polars/polars/utils/_polars_version.py b/py-polars/polars/utils/_polars_version.py new file mode 100644 index 000000000000..1a7da238e85f --- /dev/null +++ b/py-polars/polars/utils/_polars_version.py @@ -0,0 +1,19 @@ +try: + import polars.polars as plr + + _POLARS_VERSION = plr.__version__ +except ImportError: + # This is only useful for documentation + import warnings + + warnings.warn("Polars binary is missing!", stacklevel=2) + _POLARS_VERSION = "" + + +def get_polars_version() -> str: + """ + Return the version of the Python Polars package as a string. + + If the Polars binary is missing, returns an empty string. + """ + return _POLARS_VERSION diff --git a/py-polars/polars/utils/build_info.py b/py-polars/polars/utils/build_info.py deleted file mode 100644 index efc2c447dddf..000000000000 --- a/py-polars/polars/utils/build_info.py +++ /dev/null @@ -1,24 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from polars.utils.polars_version import get_polars_version - -try: - from polars.polars import _build_info_ -except ImportError: - _build_info_ = {} - -_build_info_["version"] = get_polars_version() or "" - - -def build_info() -> dict[str, Any]: - """ - Return a dict with Polars build information. - - If Polars was compiled with "build_info" feature gate return the full build info, - otherwise only version is included. The full build information dict contains - the following keys ['build', 'info-time', 'dependencies', 'features', 'host', - 'target', 'git', 'version']. - """ - return _build_info_ diff --git a/py-polars/polars/utils/convert.py b/py-polars/polars/utils/convert.py index 638506622fe5..7008f82235ff 100644 --- a/py-polars/polars/utils/convert.py +++ b/py-polars/polars/utils/convert.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from datetime import datetime, time, timedelta, timezone +from datetime import date, datetime, time, timedelta, timezone from decimal import Context from functools import lru_cache from typing import TYPE_CHECKING, Any, Callable, Sequence, TypeVar, overload @@ -10,7 +10,7 @@ if TYPE_CHECKING: from collections.abc import Reversible - from datetime import date, tzinfo + from datetime import tzinfo from decimal import Decimal from polars.type_aliases import TimeUnit @@ -51,6 +51,7 @@ def get_zoneinfo(key: str) -> ZoneInfo: # noqa: D103 US_PER_SECOND = 1_000_000 MS_PER_SECOND = 1_000 +EPOCH_DATE = date(1970, 1, 1) EPOCH = datetime(1970, 1, 1).replace(tzinfo=None) EPOCH_UTC = datetime(1970, 1, 1, tzinfo=timezone.utc) @@ -108,14 +109,13 @@ def _time_to_pl_time(t: time) -> int: def _date_to_pl_date(d: date) -> int: - dt = datetime.combine(d, datetime.min.time()).replace(tzinfo=timezone.utc) - return int(dt.timestamp()) // SECONDS_PER_DAY + return (d - EPOCH_DATE).days def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit | None) -> int: """Convert a python datetime to a timestamp in given time unit.""" if dt.tzinfo is None: - # Make sure to use UTC rather than system time zone. + # Make sure to use UTC rather than system time zone dt = dt.replace(tzinfo=timezone.utc) microseconds = dt.microsecond seconds = _timestamp_in_seconds(dt) @@ -218,8 +218,8 @@ def _localize(dt: datetime, time_zone: str) -> datetime: return dt.astimezone(_tzinfo) -def _datetime_for_anyvalue(dt: datetime) -> tuple[int, int]: - """Used in pyo3 anyvalue conversion.""" +def _datetime_for_any_value(dt: datetime) -> tuple[int, int]: + """Used in PyO3 AnyValue conversion.""" # returns (s, ms) if dt.tzinfo is None: return ( @@ -229,8 +229,8 @@ def _datetime_for_anyvalue(dt: datetime) -> tuple[int, int]: return (_timestamp_in_seconds(dt), dt.microsecond) -def _datetime_for_anyvalue_windows(dt: datetime) -> tuple[float, int]: - """Used in pyo3 anyvalue conversion.""" +def _datetime_for_any_value_windows(dt: datetime) -> tuple[float, int]: + """Used in PyO3 AnyValue conversion.""" if dt.tzinfo is None: dt = _localize(dt, "UTC") # returns (s, ms) diff --git a/py-polars/polars/utils/deprecation.py b/py-polars/polars/utils/deprecation.py index a0b6ad6793d8..a95d711ebc5f 100644 --- a/py-polars/polars/utils/deprecation.py +++ b/py-polars/polars/utils/deprecation.py @@ -92,17 +92,20 @@ def myfunc(new_name): def decorate(function: Callable[P, T]) -> Callable[P, T]: @wraps(function) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - if param_args := kwargs.pop(old_name, []): - issue_deprecation_warning( - f"named `{old_name}` param is deprecated; use positional `*args` instead.", - version=version, - ) - if param_args: - if not isinstance(param_args, Sequence) or isinstance(param_args, str): - param_args = (param_args,) - elif not isinstance(param_args, tuple): - param_args = tuple(param_args) - args = args + param_args # type: ignore[assignment] + try: + param_args = kwargs.pop(old_name) + except KeyError: + return function(*args, **kwargs) + + issue_deprecation_warning( + f"named `{old_name}` param is deprecated; use positional `*args` instead.", + version=version, + ) + if not isinstance(param_args, Sequence) or isinstance(param_args, str): + param_args = (param_args,) + elif not isinstance(param_args, tuple): + param_args = tuple(param_args) + args = args + param_args # type: ignore[assignment] return function(*args, **kwargs) wrapper.__signature__ = inspect.signature(function) # type: ignore[attr-defined] diff --git a/py-polars/polars/utils/meta.py b/py-polars/polars/utils/meta.py deleted file mode 100644 index 62d5e5b65ee5..000000000000 --- a/py-polars/polars/utils/meta.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Various public utility functions.""" -from __future__ import annotations - -import contextlib -from typing import TYPE_CHECKING - -with contextlib.suppress(ImportError): # Module not available when building docs - import polars.polars as plr - -if TYPE_CHECKING: - from polars.datatypes import DataType - - -def get_index_type() -> DataType: - """ - Get the datatype used for Polars indexing. - - Returns - ------- - DataType - :class:`UInt32` in regular Polars, :class:`UInt64` in bigidx Polars. - - Examples - -------- - >>> pl.get_index_type() - UInt32 - """ - return plr.get_index_type() - - -def threadpool_size() -> int: - """ - Get the number of threads in the Polars thread pool. - - Notes - ----- - The threadpool size can be overridden by setting the `POLARS_MAX_THREADS` - environment variable before process start. (The thread pool is not behind a - lock, so it cannot be modified once set). A reasonable use-case for this might - be temporarily setting max threads to a low value before importing polars in a - pyspark UDF or similar context. Otherwise, it is strongly recommended not to - override this value as it will be set automatically by the engine. - - Examples - -------- - >>> pl.threadpool_size() # doctest: +SKIP - 24 - """ - return plr.threadpool_size() diff --git a/py-polars/polars/utils/polars_version.py b/py-polars/polars/utils/polars_version.py deleted file mode 100644 index 9f3d3360507e..000000000000 --- a/py-polars/polars/utils/polars_version.py +++ /dev/null @@ -1,19 +0,0 @@ -try: - from polars.polars import get_polars_version as _get_polars_version - - polars_version_string = _get_polars_version() -except ImportError: - # this is only useful for documentation - import warnings - - warnings.warn("polars binary missing!", stacklevel=2) - polars_version_string = "" - - -def get_polars_version() -> str: - """ - Return the version of the Python Polars package as a string. - - If the Polars binary is missing, returns an empty string. - """ - return polars_version_string diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 6c98f053c452..57c58094dbef 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -25,6 +25,8 @@ Union, ) +from polars.utils.various import re_escape + if TYPE_CHECKING: from dis import Instruction @@ -132,29 +134,40 @@ class OpNames: ) ) -# python functions that we can map to native expressions +# python attrs/funcs that map to native expressions +_PYTHON_ATTRS_MAP = { + "date": "dt.date()", + "day": "dt.day()", + "hour": "dt.hour()", + "microsecond": "dt.microsecond()", + "minute": "dt.minute()", + "month": "dt.month()", + "second": "dt.second()", + "year": "dt.year()", +} _PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "String"} _PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"} _PYTHON_METHODS_MAP = { + # string + "endswith": "str.ends_with", "lower": "str.to_lowercase", + "lstrip": "str.strip_chars_start", + "rstrip": "str.strip_chars_end", + "startswith": "str.starts_with", + "strip": "str.strip_chars", "title": "str.to_titlecase", "upper": "str.to_uppercase", + # temporal + "date": "dt.date", + "isoweekday": "dt.weekday", + "time": "dt.time", } -_FUNCTION_KINDS: list[dict[str, list[AbstractSet[str]]]] = [ - # lambda x: module.func(CONSTANT) - { - "argument_1_opname": [{"LOAD_CONST"}], - "argument_2_opname": [], - "module_opname": [OpNames.LOAD_ATTR], - "attribute_opname": [], - "module_name": [_NUMPY_MODULE_ALIASES], - "attribute_name": [], - "function_name": [_NUMPY_FUNCTIONS], - }, - # lambda x: module.func(x) +_MODULE_FUNCTIONS: list[dict[str, list[AbstractSet[str]]]] = [ + # lambda x: numpy.func(x) + # lambda x: numpy.func(CONSTANT) { - "argument_1_opname": [{"LOAD_FAST"}], + "argument_1_opname": [{"LOAD_FAST", "LOAD_CONST"}], "argument_2_opname": [], "module_opname": [OpNames.LOAD_ATTR], "attribute_opname": [], @@ -162,6 +175,7 @@ class OpNames: "attribute_name": [], "function_name": [_NUMPY_FUNCTIONS], }, + # lambda x: json.loads(x) { "argument_1_opname": [{"LOAD_FAST"}], "argument_2_opname": [], @@ -171,7 +185,7 @@ class OpNames: "attribute_name": [], "function_name": [{"loads"}], }, - # lambda x: module.func(x, CONSTANT) + # lambda x: datetime.strptime(x, CONSTANT) { "argument_1_opname": [{"LOAD_FAST"}], "argument_2_opname": [{"LOAD_CONST"}], @@ -194,13 +208,12 @@ class OpNames: ] # In addition to `lambda x: func(x)`, also support cases when a unary operation # has been applied to `x`, like `lambda x: func(-x)` or `lambda x: func(~x)`. -_FUNCTION_KINDS = [ - # Dict entry 1 has incompatible type "str": "object"; - # expected "str": "list[AbstractSet[str]]" +_MODULE_FUNCTIONS = [ {**kind, "argument_1_unary_opname": unary} # type: ignore[dict-item] - for kind in _FUNCTION_KINDS + for kind in _MODULE_FUNCTIONS for unary in [[set(OpNames.UNARY)], []] ] +_RE_IMPLICIT_BOOL = re.compile(r'pl\.col\("([^"]*)"\) & pl\.col\("\1"\)\.(.+)') def _get_all_caller_variables() -> dict[str, Any]: @@ -252,6 +265,12 @@ def __init__(self, function: Callable[[Any], Any], map_target: MapTarget): instructions=original_instructions, ) + def _omit_implicit_bool(self, expr: str) -> str: + """Drop extraneous/implied bool (eg: `pl.col("d") & pl.col("d").dt.date()`).""" + while _RE_IMPLICIT_BOOL.search(expr): + expr = _RE_IMPLICIT_BOOL.sub(repl=r'pl.col("\1").\2', string=expr) + return expr + @staticmethod def _get_param_name(function: Callable[[Any], Any]) -> str | None: """Return single function parameter name.""" @@ -415,11 +434,13 @@ def to_expression(self, col: str) -> str | None: # constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn if "pl.col(" not in polars_expr: return None - elif self._map_target == "series": - target_name = self._get_target_name(col, polars_expr) - return polars_expr.replace(f'pl.col("{col}")', target_name) else: - return polars_expr + polars_expr = self._omit_implicit_bool(polars_expr) + if self._map_target == "series": + target_name = self._get_target_name(col, polars_expr) + return polars_expr.replace(f'pl.col("{col}")', target_name) + else: + return polars_expr def warn( self, @@ -562,7 +583,7 @@ def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str # But, if e1 << e2 was valid, then e2 must have been positive. # Hence, the output of 2**e2 can be safely cast to Int64, which # may be necessary if chaining operations which assume Int64 output. - return f"({e1}*2**{e2}).cast(pl.Int64)" + return f"({e1} * 2**{e2}).cast(pl.Int64)" elif op == ">>": # Motivation for the cast is the same as in the '<<' case above. return f"({e1} / 2**{e2}).cast(pl.Int64)" @@ -656,7 +677,8 @@ def _matches( idx: int, *, opnames: list[AbstractSet[str]], - argvals: list[AbstractSet[Any] | dict[Any, Any]] | None, + argvals: list[AbstractSet[Any] | dict[Any, Any] | None] | None, + is_attr: bool = False, ) -> list[Instruction]: """ Check if a sequence of Instructions matches the specified ops/argvals. @@ -669,9 +691,19 @@ def _matches( The full opname sequence that defines a match. argvals Associated argvals that must also match (in same position as opnames). + is_attr + Indicate if the match represents pure attribute access (cannot be called). """ n_required_ops, argvals = len(opnames), argvals or [] - instructions = self._instructions[idx : idx + n_required_ops] + idx_offset = idx + n_required_ops + if ( + is_attr + and (trailing_inst := self._instructions[idx_offset : idx_offset + 1]) + and trailing_inst[0].opname in OpNames.CALL # not pure attr if called + ): + return [] + + instructions = self._instructions[idx:idx_offset] if len(instructions) == n_required_ops and all( inst.opname in match_opnames and (match_argval is None or inst.argval in match_argval) @@ -702,12 +734,30 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: self._rewrite_functions, self._rewrite_methods, self._rewrite_builtins, + self._rewrite_attrs, ) ): updated_instructions.append(inst) idx += increment or 1 return updated_instructions + def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> int: + """Replace python attribute lookup with synthetic POLARS_EXPRESSION op.""" + if matching_instructions := self._matches( + idx, + opnames=[{"LOAD_FAST"}, {"LOAD_ATTR"}], + argvals=[None, _PYTHON_ATTRS_MAP], + is_attr=True, + ): + inst = matching_instructions[1] + expr_name = _PYTHON_ATTRS_MAP[inst.argval] + px = inst._replace( + opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name + ) + updated_instructions.extend([matching_instructions[0], px]) + + return len(matching_instructions) + def _rewrite_builtins( self, idx: int, updated_instructions: list[Instruction] ) -> int: @@ -722,7 +772,7 @@ def _rewrite_builtins( dtype = _PYTHON_CASTS_MAP[argval] argval = f"cast(pl.{dtype})" - synthetic_call = inst1._replace( + px = inst1._replace( opname="POLARS_EXPRESSION", argval=argval, argrepr=argval, @@ -730,7 +780,7 @@ def _rewrite_builtins( ) # POLARS_EXPRESSION is mapped as a unary op, so switch instruction order operand = inst2._replace(offset=inst1.offset) - updated_instructions.extend((operand, synthetic_call)) + updated_instructions.extend((operand, px)) return len(matching_instructions) @@ -738,7 +788,7 @@ def _rewrite_functions( self, idx: int, updated_instructions: list[Instruction] ) -> int: """Replace function calls with a synthetic POLARS_EXPRESSION op.""" - for function_kind in _FUNCTION_KINDS: + for function_kind in _MODULE_FUNCTIONS: opnames: list[AbstractSet[str]] = [ {"LOAD_GLOBAL", "LOAD_DEREF"}, *function_kind["module_opname"], @@ -775,22 +825,24 @@ def _rewrite_functions( return 0 else: expr_name = inst2.argval - synthetic_call = inst1._replace( + + px = inst1._replace( opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name, offset=inst3.offset, ) + # POLARS_EXPRESSION is mapped as a unary op, so switch instruction order operand = inst3._replace(offset=inst1.offset) updated_instructions.extend( ( operand, matching_instructions[3 + attribute_count], - synthetic_call, + px, ) if function_kind["argument_1_unary_opname"] - else (operand, synthetic_call) + else (operand, px) ) return len(matching_instructions) @@ -800,20 +852,40 @@ def _rewrite_methods( self, idx: int, updated_instructions: list[Instruction] ) -> int: """Replace python method calls with synthetic POLARS_EXPRESSION op.""" - if matching_instructions := self._matches( - idx, - opnames=[ - OpNames.LOAD_ATTR if _MIN_PY312 else {"LOAD_METHOD"}, - OpNames.CALL, - ], - argvals=[_PYTHON_METHODS_MAP], + LOAD_METHOD = OpNames.LOAD_ATTR if _MIN_PY312 else {"LOAD_METHOD"} + if matching_instructions := ( + # method call with one basic arg, eg: "s.endswith('!')" + self._matches( + idx, + opnames=[LOAD_METHOD, {"LOAD_CONST"}, OpNames.CALL], + argvals=[_PYTHON_METHODS_MAP], + ) + or + # method call with no arg, eg: "s.lower()" + self._matches( + idx, + opnames=[LOAD_METHOD, OpNames.CALL], + argvals=[_PYTHON_METHODS_MAP], + ) ): inst = matching_instructions[0] - expr_name = _PYTHON_METHODS_MAP[inst.argval] - synthetic_call = inst._replace( - opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name - ) - updated_instructions.append(synthetic_call) + expr = _PYTHON_METHODS_MAP[inst.argval] + + if matching_instructions[1].opname == "LOAD_CONST": + param_value = matching_instructions[1].argval + if isinstance(param_value, tuple) and expr in ( + "str.starts_with", + "str.ends_with", + ): + starts, ends = ("^", "") if "starts" in expr else ("", "$") + rx = "|".join(re_escape(v) for v in param_value) + q = '"' if "'" in param_value else "'" + expr = f"str.contains(r{q}{starts}({rx}){ends}{q})" + else: + expr += f"({param_value!r})" + + px = inst._replace(opname="POLARS_EXPRESSION", argval=expr, argrepr=expr) + updated_instructions.append(px) return len(matching_instructions) diff --git a/py-polars/polars/utils/various.py b/py-polars/polars/utils/various.py index 47d9843fec56..d06e604ead34 100644 --- a/py-polars/polars/utils/various.py +++ b/py-polars/polars/utils/various.py @@ -399,6 +399,9 @@ def __get__( # type: ignore[override] return self # type: ignore[return-value] +BUILDING_SPHINX_DOCS = os.getenv("BUILDING_SPHINX_DOCS") + + class _NoDefault(Enum): # "borrowed" from # https://github.com/pandas-dev/pandas/blob/e7859983a814b1823cf26e3b491ae2fa3be47c53/pandas/_libs/lib.pyx#L2736-L2748 @@ -559,3 +562,11 @@ def parse_percentiles( at_or_above_50_percentiles = [0.5, *at_or_above_50_percentiles] return [*sub_50_percentiles, *at_or_above_50_percentiles] + + +def re_escape(s: str) -> str: + """Escape a string for use in a Polars (Rust) regex.""" + # note: almost the same as the standard python 're.escape' function, but + # escapes _only_ those metachars with meaning to the rust regex crate + re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-" + return re.sub(f"([{re_rust_metachars}])", r"\\\1", s) diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index fc19c4f34d31..4911687cea2e 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -108,7 +108,6 @@ ignore_missing_imports = true module = [ "IPython.*", "matplotlib.*", - "dataframe_api_compat.*", ] follow_imports = "skip" @@ -125,6 +124,7 @@ warn_return_any = false line-length = 88 fix = true +[tool.ruff.lint] select = [ "E", # pycodestyle "W", # pycodestyle @@ -178,23 +178,23 @@ ignore = [ "W191", ] -[tool.ruff.format] -docstring-code-format = true +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = ["D100", "D103", "B018", "FBT001"] -[tool.ruff.pycodestyle] +[tool.ruff.lint.pycodestyle] max-doc-length = 88 [tool.ruff.lint.pydocstyle] convention = "numpy" -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" -[tool.ruff.flake8-type-checking] +[tool.ruff.lint.flake8-type-checking] strict = true -[tool.ruff.per-file-ignores] -"tests/**/*.py" = ["D100", "D103", "B018", "FBT001"] +[tool.ruff.format] +docstring-code-format = true [tool.pytest.ini_options] addopts = [ @@ -223,6 +223,9 @@ filterwarnings = [ "ignore:datetime.datetime.utcnow\\(\\) is deprecated.*:DeprecationWarning", # Introspection under PyCharm IDE can generate this in Python 3.12 "ignore:.*co_lnotab is deprecated, use co_lines.*:DeprecationWarning", + # TODO: Excel tests lead to unclosed file warnings + # https://github.com/pola-rs/polars/issues/14466 + "ignore:unclosed file.*:ResourceWarning", ] xfail_strict = true diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index c7a0592bf09b..4a3785e77f15 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -2,8 +2,6 @@ # We're not pinning package dependencies, because our tests need to pass with the # latest version of the packages. ---prefer-binary - # ----- # BUILD # ----- @@ -37,14 +35,12 @@ s3fs[boto3] # Spreadsheet ezodf lxml -fastexcel>=0.7.0; platform_system != 'Windows' +fastexcel>=0.8.0 openpyxl pyxlsb xlsx2csv XlsxWriter deltalake>=0.14.0 -# Dataframe interchange protocol -dataframe-api-compat>=0.1.6 pyiceberg>=0.5.0 # Csv zstandard @@ -58,13 +54,13 @@ gevent # TOOLING # ------- -hypothesis==6.92.1 -pytest==7.4.0 +hypothesis==6.97.4 +pytest==8.0.0 pytest-cov==4.1.0 pytest-xdist==3.5.0 # Need moto.server to mock s3fs - see: https://github.com/aio-libs/aiobotocore/issues/755 -moto[s3]==4.2.2 +moto[s3]==5.0.0 flask flask-cors diff --git a/py-polars/requirements-lint.txt b/py-polars/requirements-lint.txt index 2bede547ead6..225616bb2c75 100644 --- a/py-polars/requirements-lint.txt +++ b/py-polars/requirements-lint.txt @@ -1,3 +1,3 @@ mypy==1.8.0 -ruff==0.1.9 -typos==1.16.21 +ruff==0.2.0 +typos==1.17.2 diff --git a/py-polars/src/arrow_interop/to_rust.rs b/py-polars/src/arrow_interop/to_rust.rs index 612cc6445098..ecdd13e1a364 100644 --- a/py-polars/src/arrow_interop/to_rust.rs +++ b/py-polars/src/arrow_interop/to_rust.rs @@ -98,7 +98,7 @@ pub fn to_rust_df(rb: &[&PyAny]) -> PyResult { }?; // no need to check as a record batch has the same guarantees - Ok(DataFrame::new_no_checks(columns)) + Ok(unsafe { DataFrame::new_no_checks(columns) }) }) .collect::>>()?; diff --git a/py-polars/src/batched_csv.rs b/py-polars/src/batched_csv.rs index 82505156c635..223af6b96350 100644 --- a/py-polars/src/batched_csv.rs +++ b/py-polars/src/batched_csv.rs @@ -142,7 +142,7 @@ impl PyBatchedCsv { } .map_err(PyPolarsErr::from)?; - // safety: same memory layout + // SAFETY: same memory layout let batches = unsafe { std::mem::transmute::>, Option>>(batches) }; diff --git a/py-polars/src/conversion/anyvalue.rs b/py-polars/src/conversion/any_value.rs similarity index 99% rename from py-polars/src/conversion/anyvalue.rs rename to py-polars/src/conversion/any_value.rs index 35df48cf2a30..a66ec63d5354 100644 --- a/py-polars/src/conversion/anyvalue.rs +++ b/py-polars/src/conversion/any_value.rs @@ -362,7 +362,7 @@ fn convert_datetime(ob: &PyAny) -> PyResult> { #[cfg(target_arch = "windows")] let (seconds, microseconds) = { let convert = UTILS - .getattr(py, intern!(py, "_datetime_for_anyvalue_windows")) + .getattr(py, intern!(py, "_datetime_for_any_value_windows")) .unwrap(); let out = convert.call1(py, (ob,)).unwrap(); let out: (i64, i64) = out.extract(py).unwrap(); @@ -372,7 +372,7 @@ fn convert_datetime(ob: &PyAny) -> PyResult> { #[cfg(not(target_arch = "windows"))] let (seconds, microseconds) = { let convert = UTILS - .getattr(py, intern!(py, "_datetime_for_anyvalue")) + .getattr(py, intern!(py, "_datetime_for_any_value")) .unwrap(); let out = convert.call1(py, (ob,)).unwrap(); let out: (i64, i64) = out.extract(py).unwrap(); diff --git a/py-polars/src/conversion/chunked_array.rs b/py-polars/src/conversion/chunked_array.rs index af0bd1d023f6..914b8a8017ea 100644 --- a/py-polars/src/conversion/chunked_array.rs +++ b/py-polars/src/conversion/chunked_array.rs @@ -1,9 +1,6 @@ -use polars::prelude::AnyValue; -#[cfg(feature = "cloud")] -use pyo3::conversion::{FromPyObject, IntoPy}; +use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyList, PyTuple}; -use pyo3::{intern, PyAny, PyResult}; use super::{decimal_to_digits, struct_dict}; use crate::prelude::*; @@ -141,16 +138,21 @@ impl ToPyObject for Wrap<&DatetimeChunked> { impl ToPyObject for Wrap<&TimeChunked> { fn to_object(&self, py: Python) -> PyObject { - let utils = UTILS.as_ref(py); - let convert = utils.getattr(intern!(py, "_to_python_time")).unwrap(); - let iter = self - .0 - .into_iter() - .map(|opt_v| opt_v.map(|v| convert.call1((v,)).unwrap())); + let iter = time_to_pyobject_iter(py, self.0); PyList::new(py, iter).into_py(py) } } +pub(crate) fn time_to_pyobject_iter<'a>( + py: Python<'a>, + ca: &'a TimeChunked, +) -> impl ExactSizeIterator> { + let utils = UTILS.as_ref(py); + let convert = utils.getattr(intern!(py, "_to_python_time")).unwrap(); + ca.0.into_iter() + .map(|opt_v| opt_v.map(|v| convert.call1((v,)).unwrap())) +} + impl ToPyObject for Wrap<&DateChunked> { fn to_object(&self, py: Python) -> PyObject { let utils = UTILS.as_ref(py); @@ -165,29 +167,36 @@ impl ToPyObject for Wrap<&DateChunked> { impl ToPyObject for Wrap<&DecimalChunked> { fn to_object(&self, py: Python) -> PyObject { - let utils = UTILS.as_ref(py); - let convert = utils.getattr(intern!(py, "_to_python_decimal")).unwrap(); - let py_scale = (-(self.0.scale() as i32)).to_object(py); - // if we don't know precision, the only safe bet is to set it to 39 - let py_precision = self.0.precision().unwrap_or(39).to_object(py); - let iter = self.0.into_iter().map(|opt_v| { - opt_v.map(|v| { - // TODO! use anyvalue so that we have a single impl. - const N: usize = 3; - let mut buf = [0_u128; N]; - let n_digits = decimal_to_digits(v.abs(), &mut buf); - let buf = unsafe { - std::slice::from_raw_parts( - buf.as_slice().as_ptr() as *const u8, - N * std::mem::size_of::(), - ) - }; - let digits = PyTuple::new(py, buf.iter().take(n_digits)); - convert - .call1((v.is_negative() as u8, digits, &py_precision, &py_scale)) - .unwrap() - }) - }); + let iter = decimal_to_pyobject_iter(py, self.0); PyList::new(py, iter).into_py(py) } } + +pub(crate) fn decimal_to_pyobject_iter<'a>( + py: Python<'a>, + ca: &'a DecimalChunked, +) -> impl ExactSizeIterator> { + let utils = UTILS.as_ref(py); + let convert = utils.getattr(intern!(py, "_to_python_decimal")).unwrap(); + let py_scale = (-(ca.scale() as i32)).to_object(py); + // if we don't know precision, the only safe bet is to set it to 39 + let py_precision = ca.precision().unwrap_or(39).to_object(py); + ca.into_iter().map(move |opt_v| { + opt_v.map(|v| { + // TODO! use AnyValue so that we have a single impl. + const N: usize = 3; + let mut buf = [0_u128; N]; + let n_digits = decimal_to_digits(v.abs(), &mut buf); + let buf = unsafe { + std::slice::from_raw_parts( + buf.as_slice().as_ptr() as *const u8, + N * std::mem::size_of::(), + ) + }; + let digits = PyTuple::new(py, buf.iter().take(n_digits)); + convert + .call1((v.is_negative() as u8, digits, &py_precision, &py_scale)) + .unwrap() + }) + }) +} diff --git a/py-polars/src/conversion/mod.rs b/py-polars/src/conversion/mod.rs index d0cf23eac2a0..971b9be12dac 100644 --- a/py-polars/src/conversion/mod.rs +++ b/py-polars/src/conversion/mod.rs @@ -1,5 +1,5 @@ -pub(crate) mod anyvalue; -mod chunked_array; +pub(crate) mod any_value; +pub(crate) mod chunked_array; use std::fmt::{Display, Formatter}; use std::hash::{Hash, Hasher}; @@ -10,23 +10,18 @@ use polars::frame::row::Row; use polars::frame::NullStrategy; #[cfg(feature = "avro")] use polars::io::avro::AvroCompression; -#[cfg(feature = "ipc")] -use polars::io::ipc::IpcCompression; -use polars::prelude::AnyValue; use polars::series::ops::NullBehavior; -use polars_core::prelude::{IndexOrder, QuantileInterpolOptions}; use polars_core::utils::arrow::array::Array; use polars_core::utils::arrow::types::NativeType; use polars_lazy::prelude::*; #[cfg(feature = "cloud")] use polars_rs::io::cloud::CloudOptions; -use polars_utils::total_ord::TotalEq; +use polars_utils::total_ord::{TotalEq, TotalHash}; use pyo3::basic::CompareOp; -use pyo3::conversion::{FromPyObject, IntoPy}; use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PySequence}; -use pyo3::{intern, PyAny, PyResult}; use smartstring::alias::String as SmartString; use crate::error::PyPolarsErr; @@ -38,19 +33,19 @@ use crate::series::PySeries; use crate::{PyDataFrame, PyLazyFrame}; pub(crate) fn slice_to_wrapped(slice: &[T]) -> &[Wrap] { - // Safety: + // SAFETY: // Wrap is transparent. unsafe { std::mem::transmute(slice) } } pub(crate) fn slice_extract_wrapped(slice: &[Wrap]) -> &[T] { - // Safety: + // SAFETY: // Wrap is transparent. unsafe { std::mem::transmute(slice) } } pub(crate) fn vec_extract_wrapped(buf: Vec>) -> Vec { - // Safety: + // SAFETY: // Wrap is transparent. unsafe { std::mem::transmute(buf) } } @@ -129,7 +124,7 @@ fn struct_dict<'a>( // accept u128 array to ensure alignment is correct fn decimal_to_digits(v: i128, buf: &mut [u128; 3]) -> usize { const ZEROS: i128 = 0x3030_3030_3030_3030_3030_3030_3030_3030; - // safety: transmute is safe as there are 48 bytes in 3 128bit ints + // SAFETY: transmute is safe as there are 48 bytes in 3 128bit ints // and the minimal alignment of u8 fits u16 let buf = unsafe { std::mem::transmute::<&mut [u128; 3], &mut [u8; 48]>(buf) }; let mut buffer = itoa::Buffer::new(); @@ -498,6 +493,15 @@ impl TotalEq for ObjectValue { } } +impl TotalHash for ObjectValue { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state); + } +} + impl Display for ObjectValue { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.inner) diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index 83a2a6b62c58..b357954eb98f 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -3,22 +3,15 @@ use std::num::NonZeroUsize; use std::ops::Deref; use either::Either; -use numpy::IntoPyArray; use polars::frame::row::{rows_to_schema_supertypes, Row}; -use polars::frame::NullStrategy; #[cfg(feature = "avro")] use polars::io::avro::AvroCompression; -#[cfg(feature = "ipc")] -use polars::io::ipc::IpcCompression; use polars::io::mmap::ReaderBytes; use polars::io::RowIndex; use polars::prelude::*; use polars_core::export::arrow::datatypes::IntegerType; -use polars_core::frame::explode::MeltArgs; use polars_core::frame::*; -use polars_core::prelude::IndexOrder; use polars_core::utils::arrow::compute::cast::CastOptions; -use polars_core::utils::try_get_supertype; #[cfg(feature = "pivot")] use polars_lazy::frame::pivot::{pivot, pivot_stable}; use pyo3::prelude::*; @@ -783,7 +776,7 @@ impl PyDataFrame { s.get_object(idx).map(|any| any.into()); obj.to_object(py) }, - // safety: we are in bounds. + // SAFETY: we are in bounds. _ => unsafe { Wrap(s.get_unchecked(idx)).into_py(py) }, }), ) @@ -793,32 +786,6 @@ impl PyDataFrame { }) } - pub fn to_numpy(&self, py: Python, order: Wrap) -> Option { - let mut st = None; - for s in self.df.iter() { - let dt_i = s.dtype(); - match st { - None => st = Some(dt_i.clone()), - Some(ref mut st) => { - *st = try_get_supertype(st, dt_i).ok()?; - }, - } - } - let st = st?; - - #[rustfmt::skip] - let pyarray = match st { - DataType::UInt32 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), - DataType::UInt64 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), - DataType::Int32 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), - DataType::Int64 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), - DataType::Float32 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), - DataType::Float64 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), - _ => return None, - }; - Some(pyarray) - } - #[cfg(feature = "parquet")] #[pyo3(signature = (py_f, compression, compression_level, statistics, row_group_size, data_page_size))] pub fn write_parquet( @@ -1251,11 +1218,12 @@ impl PyDataFrame { } #[cfg(feature = "pivot")] + #[pyo3(signature = (index, columns, values, maintain_order, sort_columns, aggregate_expr, separator))] pub fn pivot_expr( &self, - values: Vec, index: Vec, columns: Vec, + values: Option>, maintain_order: bool, sort_columns: bool, aggregate_expr: Option, @@ -1265,9 +1233,9 @@ impl PyDataFrame { let agg_expr = aggregate_expr.map(|expr| expr.inner); let df = fun( &self.df, - values, index, columns, + values, sort_columns, agg_expr, separator, @@ -1400,7 +1368,11 @@ impl PyDataFrame { } #[pyo3(signature = (keep_names_as, column_names))] - pub fn transpose(&self, keep_names_as: Option<&str>, column_names: &PyAny) -> PyResult { + pub fn transpose( + &mut self, + keep_names_as: Option<&str>, + column_names: &PyAny, + ) -> PyResult { let new_col_names = if let Ok(name) = column_names.extract::>() { Some(Either::Right(name)) } else if let Ok(name) = column_names.extract::() { diff --git a/py-polars/src/expr/array.rs b/py-polars/src/expr/array.rs index bfc6c9dbccd5..5b0cb2bf365b 100644 --- a/py-polars/src/expr/array.rs +++ b/py-polars/src/expr/array.rs @@ -1,5 +1,8 @@ use polars::prelude::*; +use polars_ops::prelude::array::ArrToStructNameGenerator; +use pyo3::prelude::*; use pyo3::pymethods; +use smartstring::alias::String as SmartString; use crate::expr::PyExpr; @@ -17,6 +20,18 @@ impl PyExpr { self.inner.clone().arr().sum().into() } + fn arr_std(&self, ddof: u8) -> Self { + self.inner.clone().arr().std(ddof).into() + } + + fn arr_var(&self, ddof: u8) -> Self { + self.inner.clone().arr().var(ddof).into() + } + + fn arr_median(&self) -> Self { + self.inner.clone().arr().median().into() + } + fn arr_unique(&self, maintain_order: bool) -> Self { if maintain_order { self.inner.clone().arr().unique_stable().into() @@ -82,4 +97,23 @@ impl PyExpr { fn arr_count_matches(&self, expr: PyExpr) -> Self { self.inner.clone().arr().count_matches(expr.inner).into() } + + #[pyo3(signature = (name_gen))] + fn arr_to_struct(&self, name_gen: Option) -> PyResult { + let name_gen = name_gen.map(|lambda| { + Arc::new(move |idx: usize| { + Python::with_gil(|py| { + let out = lambda.call1(py, (idx,)).unwrap(); + let out: SmartString = out.extract::<&str>(py).unwrap().into(); + out + }) + }) as ArrToStructNameGenerator + }); + + Ok(self.inner.clone().arr().to_struct(name_gen).into()) + } + + fn arr_shift(&self, n: PyExpr) -> Self { + self.inner.clone().arr().shift(n.inner).into() + } } diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index efe999661508..67d156e9e1c7 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -3,7 +3,6 @@ use std::ops::Neg; use polars::lazy::dsl; use polars::prelude::*; use polars::series::ops::NullBehavior; -use polars_core::prelude::QuantileInterpolOptions; use polars_core::series::IsSorted; use pyo3::class::basic::CompareOp; use pyo3::prelude::*; @@ -595,17 +594,18 @@ impl PyExpr { self.inner.clone().rolling(options).into() } - fn _and(&self, expr: Self) -> Self { + fn and_(&self, expr: Self) -> Self { self.inner.clone().and(expr.inner).into() } - fn _xor(&self, expr: Self) -> Self { - self.inner.clone().xor(expr.inner).into() + fn or_(&self, expr: Self) -> Self { + self.inner.clone().or(expr.inner).into() } - fn _or(&self, expr: Self) -> Self { - self.inner.clone().or(expr.inner).into() + fn xor_(&self, expr: Self) -> Self { + self.inner.clone().xor(expr.inner).into() } + #[cfg(feature = "is_in")] fn is_in(&self, expr: Self) -> Self { self.inner.clone().is_in(expr.inner).into() @@ -684,7 +684,7 @@ impl PyExpr { self.inner.clone().exclude(columns).into() } fn exclude_dtype(&self, dtypes: Vec>) -> Self { - // Safety: + // SAFETY: // Wrap is transparent. let dtypes: Vec = unsafe { std::mem::transmute(dtypes) }; self.inner.clone().exclude_dtype(&dtypes).into() @@ -804,20 +804,10 @@ impl PyExpr { }; self.inner.clone().ewm_var(options).into() } - fn extend_constant(&self, py: Python, value: Wrap, n: usize) -> Self { - let value = value.into_py(py); + fn extend_constant(&self, value: PyExpr, n: PyExpr) -> Self { self.inner .clone() - .apply( - move |s| { - Python::with_gil(|py| { - let value = value.extract::>(py).unwrap().0; - s.extend_constant(value, n).map(Some) - }) - }, - GetOutput::same_type(), - ) - .with_fmt("extend") + .extend_constant(value.inner, n.inner) .into() } diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index e17550f8ba9e..fde544a6ce41 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -1,4 +1,3 @@ -use polars::lazy::dsl::lit; use polars::prelude::*; use polars::series::ops::NullBehavior; use pyo3::prelude::*; @@ -74,6 +73,33 @@ impl PyExpr { .into() } + fn list_median(&self) -> Self { + self.inner + .clone() + .list() + .median() + .with_fmt("list.median") + .into() + } + + fn list_std(&self, ddof: u8) -> Self { + self.inner + .clone() + .list() + .std(ddof) + .with_fmt("list.std") + .into() + } + + fn list_var(&self, ddof: u8) -> Self { + self.inner + .clone() + .list() + .var(ddof) + .with_fmt("list.var") + .into() + } + fn list_min(&self) -> Self { self.inner.clone().list().min().into() } @@ -154,7 +180,16 @@ impl PyExpr { self.inner .clone() .list() - .take(index.inner, null_on_oob) + .gather(index.inner, null_on_oob) + .into() + } + + #[cfg(feature = "list_gather")] + fn list_gather_every(&self, n: PyExpr, offset: PyExpr) -> Self { + self.inner + .clone() + .list() + .gather_every(n.inner, offset.inner) .into() } @@ -187,6 +222,10 @@ impl PyExpr { .into()) } + fn list_n_unique(&self) -> Self { + self.inner.clone().list().n_unique().into() + } + fn list_unique(&self, maintain_order: bool) -> Self { let e = self.inner.clone(); diff --git a/py-polars/src/expr/meta.rs b/py-polars/src/expr/meta.rs index 658ebc6329a2..62af2805c203 100644 --- a/py-polars/src/expr/meta.rs +++ b/py-polars/src/expr/meta.rs @@ -89,7 +89,7 @@ impl PyExpr { } #[cfg(all(feature = "json", feature = "serde_json"))] - fn meta_write_json(&self, py_f: PyObject) -> PyResult<()> { + fn serialize(&self, py_f: PyObject) -> PyResult<()> { let file = BufWriter::new(get_file_like(py_f, true)?); serde_json::to_writer(file, &self.inner) .map_err(|err| PyValueError::new_err(format!("{err:?}")))?; @@ -97,17 +97,28 @@ impl PyExpr { } #[staticmethod] - fn meta_read_json(value: &str) -> PyResult { - #[cfg(feature = "json")] - { - let inner: polars_lazy::prelude::Expr = serde_json::from_str(value) - .map_err(|_| PyPolarsErr::from(polars_err!(ComputeError: "could not serialize")))?; - Ok(PyExpr { inner }) - } - #[cfg(not(feature = "json"))] - { - panic!("activate 'json' feature") - } + #[cfg(feature = "json")] + fn deserialize(py_f: PyObject) -> PyResult { + // it is faster to first read to memory and then parse: https://github.com/serde-rs/json/issues/160 + // so don't bother with files. + let mut json = String::new(); + let _ = get_file_like(py_f, false)? + .read_to_string(&mut json) + .unwrap(); + + // SAFETY: + // we skipped the serializing/deserializing of the static in lifetime in `DataType` + // so we actually don't have a lifetime at all when serializing. + + // &str still has a lifetime. But it's ok, because we drop it immediately + // in this scope + let json = unsafe { std::mem::transmute::<&'_ str, &'static str>(json.as_str()) }; + + let inner: polars_lazy::prelude::Expr = serde_json::from_str(json).map_err(|_| { + let msg = "could not deserialize input into an expression"; + PyPolarsErr::from(polars_err!(ComputeError: msg)) + })?; + Ok(PyExpr { inner }) } fn meta_tree_format(&self) -> PyResult { diff --git a/py-polars/src/expr/mod.rs b/py-polars/src/expr/mod.rs index 1ab5db4bfe66..2e329a850660 100644 --- a/py-polars/src/expr/mod.rs +++ b/py-polars/src/expr/mod.rs @@ -33,7 +33,7 @@ pub(crate) trait ToExprs { impl ToExprs for Vec { fn to_exprs(self) -> Vec { - // Safety + // SAFETY: // repr is transparent unsafe { std::mem::transmute(self) } } @@ -45,7 +45,7 @@ pub(crate) trait ToPyExprs { impl ToPyExprs for Vec { fn to_pyexprs(self) -> Vec { - // Safety + // SAFETY: // repr is transparent unsafe { std::mem::transmute(self) } } diff --git a/py-polars/src/expr/name.rs b/py-polars/src/expr/name.rs index 28b6686da6ad..821cab8fbefb 100644 --- a/py-polars/src/expr/name.rs +++ b/py-polars/src/expr/name.rs @@ -1,5 +1,6 @@ use polars::prelude::*; use pyo3::prelude::*; +use smartstring::alias::String as SmartString; use crate::PyExpr; @@ -40,4 +41,24 @@ impl PyExpr { fn name_to_uppercase(&self) -> Self { self.inner.clone().name().to_uppercase().into() } + + fn name_map_fields(&self, name_mapper: PyObject) -> Self { + let name_mapper = Arc::new(move |name: &str| { + Python::with_gil(|py| { + let out = name_mapper.call1(py, (name,)).unwrap(); + let out: SmartString = out.extract::<&str>(py).unwrap().into(); + out + }) + }) as FieldsNameMapper; + + self.inner.clone().name().map_fields(name_mapper).into() + } + + fn name_prefix_fields(&self, prefix: &str) -> Self { + self.inner.clone().name().prefix_fields(prefix).into() + } + + fn name_suffix_fields(&self, suffix: &str) -> Self { + self.inner.clone().name().suffix_fields(suffix).into() + } } diff --git a/py-polars/src/expr/rolling.rs b/py-polars/src/expr/rolling.rs index 917aa4936ad9..128596bb13a9 100644 --- a/py-polars/src/expr/rolling.rs +++ b/py-polars/src/expr/rolling.rs @@ -1,7 +1,6 @@ use std::any::Any; use polars::prelude::*; -use polars_core::prelude::QuantileInterpolOptions; use pyo3::prelude::*; use pyo3::types::PyFloat; diff --git a/py-polars/src/file.rs b/py-polars/src/file.rs index a6454448a983..e3e8e7363ef8 100644 --- a/py-polars/src/file.rs +++ b/py-polars/src/file.rs @@ -3,6 +3,7 @@ use std::io; use std::io::{BufReader, Cursor, Read, Seek, SeekFrom, Write}; use polars::io::mmap::MmapBytesReader; +use polars_error::polars_warn; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyString}; @@ -216,7 +217,7 @@ pub fn get_mmap_bytes_reader<'a>(py_f: &'a PyAny) -> PyResult) -> PyResult { let e = dsl::sum_horizontal(exprs).map_err(PyPolarsErr::from)?; Ok(e.into()) } + +#[pyfunction] +pub fn mean_horizontal(exprs: Vec) -> PyResult { + let exprs = exprs.to_exprs(); + let e = dsl::mean_horizontal(exprs).map_err(PyPolarsErr::from)?; + Ok(e.into()) +} diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index deda777bdb8f..704658f9d8cc 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -1,5 +1,4 @@ use polars::lazy::dsl; -use polars::lazy::dsl::Expr; use polars::prelude::*; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; @@ -8,7 +7,7 @@ use pyo3::types::{PyBool, PyBytes, PyFloat, PyInt, PyString}; use crate::conversion::{get_lf, Wrap}; use crate::expr::ToExprs; use crate::map::lazy::binary_lambda; -use crate::prelude::{vec_extract_wrapped, DataType, DatetimeArgs, DurationArgs, ObjectValue}; +use crate::prelude::{vec_extract_wrapped, ObjectValue}; use crate::{map, PyDataFrame, PyExpr, PyLazyFrame, PyPolarsErr, PySeries}; macro_rules! set_unwrapped_or_0 { @@ -452,7 +451,7 @@ pub fn repeat(value: PyExpr, n: PyExpr, dtype: Option>) -> PyResu } if let Expr::Literal(lv) = &value { - let av = lv.to_anyvalue().unwrap(); + let av = lv.to_any_value().unwrap(); // Integer inputs that fit in Int32 are parsed as such if let DataType::Int64 = av.dtype() { let int_value = av.try_extract::().unwrap(); diff --git a/py-polars/src/functions/meta.rs b/py-polars/src/functions/meta.rs index 1efed10763df..bc43657e1b12 100644 --- a/py-polars/src/functions/meta.rs +++ b/py-polars/src/functions/meta.rs @@ -6,19 +6,13 @@ use pyo3::prelude::*; use crate::conversion::Wrap; -const VERSION: &str = env!("CARGO_PKG_VERSION"); -#[pyfunction] -pub fn get_polars_version() -> &'static str { - VERSION -} - #[pyfunction] pub fn get_index_type(py: Python) -> PyObject { Wrap(IDX_DTYPE).to_object(py) } #[pyfunction] -pub fn threadpool_size() -> usize { +pub fn thread_pool_size() -> usize { POOL.current_num_threads() } diff --git a/py-polars/src/lazyframe/mod.rs b/py-polars/src/lazyframe/mod.rs index 050994411b4a..eeb9f28eb731 100644 --- a/py-polars/src/lazyframe/mod.rs +++ b/py-polars/src/lazyframe/mod.rs @@ -6,30 +6,15 @@ use std::num::NonZeroUsize; use std::path::PathBuf; pub use exitable::PyInProcessQuery; -#[cfg(feature = "csv")] -use polars::io::csv::SerializeOptions; use polars::io::RowIndex; -#[cfg(feature = "csv")] -use polars::lazy::frame::LazyCsvReader; -#[cfg(feature = "json")] -use polars::lazy::frame::LazyJsonLineReader; -use polars::lazy::frame::{AllowedOptimizations, LazyFrame}; -use polars::lazy::prelude::col; -#[cfg(feature = "csv")] -use polars::prelude::CsvEncoding; -use polars::prelude::{ClosedWindow, Field, JoinType, Schema}; use polars::time::*; -use polars_core::frame::explode::MeltArgs; -use polars_core::frame::UniqueKeepStrategy; use polars_core::prelude::*; -use polars_ops::prelude::AsOfOptions; use polars_rs::io::cloud::CloudOptions; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyList}; use crate::arrow_interop::to_rust::pyarrow_schema_to_rust; -use crate::conversion::Wrap; use crate::error::PyPolarsErr; use crate::expr::ToExprs; use crate::file::get_file_like; @@ -99,7 +84,7 @@ impl PyLazyFrame { .read_to_string(&mut json) .unwrap(); - // Safety + // SAFETY: // we skipped the serializing/deserializing of the static in lifetime in `DataType` // so we actually don't have a lifetime at all when serializing. @@ -378,6 +363,19 @@ impl PyLazyFrame { .map_err(PyPolarsErr::from)?; Ok(result) } + + fn describe_plan_tree(&self) -> String { + self.ldf.describe_plan_tree() + } + + fn describe_optimized_plan_tree(&self) -> PyResult { + let result = self + .ldf + .describe_optimized_plan_tree() + .map_err(PyPolarsErr::from)?; + Ok(result) + } + fn to_dot(&self, optimized: bool) -> PyResult { let result = self.ldf.to_dot(optimized).map_err(PyPolarsErr::from)?; Ok(result) diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index d65fd6951a8d..1dcb20557e1a 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -37,6 +37,7 @@ mod py_modules; mod series; #[cfg(feature = "sql")] mod sql; +mod to_numpy; mod utils; #[cfg(all(target_family = "unix", not(use_mimalloc)))] @@ -128,6 +129,8 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { .unwrap(); m.add_wrapped(wrap_pyfunction!(functions::sum_horizontal)) .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::mean_horizontal)) + .unwrap(); // Functions - lazy m.add_wrapped(wrap_pyfunction!(functions::arg_sort_by)) @@ -203,11 +206,9 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { .unwrap(); // Functions - meta - m.add_wrapped(wrap_pyfunction!(functions::get_polars_version)) - .unwrap(); m.add_wrapped(wrap_pyfunction!(functions::get_index_type)) .unwrap(); - m.add_wrapped(wrap_pyfunction!(functions::threadpool_size)) + m.add_wrapped(wrap_pyfunction!(functions::thread_pool_size)) .unwrap(); m.add_wrapped(wrap_pyfunction!(functions::enable_string_cache)) .unwrap(); @@ -297,9 +298,10 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { .unwrap(); // Build info + m.add("__version__", env!("CARGO_PKG_VERSION"))?; #[cfg(feature = "build_info")] m.add( - "_build_info_", + "__build__", pyo3_built!(py, build, "build", "time", "deps", "features", "host", "target", "git"), )?; diff --git a/py-polars/src/map/dataframe.rs b/py-polars/src/map/dataframe.rs index cd01d0b9eeb8..52882fd5db3d 100644 --- a/py-polars/src/map/dataframe.rs +++ b/py-polars/src/map/dataframe.rs @@ -1,14 +1,10 @@ use polars::prelude::*; use polars_core::frame::row::{rows_to_schema_first_non_null, Row}; use polars_core::series::SeriesIter; -use pyo3::conversion::{FromPyObject, IntoPy}; use pyo3::prelude::*; use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple}; use super::*; -use crate::conversion::Wrap; -use crate::error::PyPolarsErr; -use crate::series::PySeries; use crate::PyDataFrame; fn get_iters(df: &DataFrame) -> Vec { @@ -277,7 +273,7 @@ pub fn apply_lambda_with_rows_output<'a>( row_buf.0.push(v); } let ptr = &row_buf as *const Row; - // Safety: + // SAFETY: // we know that row constructor of polars dataframe does not keep a reference // to the row. Before we mutate the row buf again, the reference is dropped. // we only cannot prove it to the compiler. @@ -296,7 +292,7 @@ pub fn apply_lambda_with_rows_output<'a>( let schema = rows_to_schema_first_non_null(&buf, Some(50)); if init_null_count > 0 { - // Safety: we know the iterators size + // SAFETY: we know the iterators size let iter = unsafe { (0..init_null_count) .map(|_| Ok(&null_row)) @@ -306,7 +302,7 @@ pub fn apply_lambda_with_rows_output<'a>( }; DataFrame::try_from_rows_iter_and_schema(iter, &schema) } else { - // Safety: we know the iterators size + // SAFETY: we know the iterators size let iter = unsafe { buf.iter() .map(Ok) diff --git a/py-polars/src/map/lazy.rs b/py-polars/src/map/lazy.rs index 243772fa8d36..75084783a295 100644 --- a/py-polars/src/map/lazy.rs +++ b/py-polars/src/map/lazy.rs @@ -103,7 +103,7 @@ pub(crate) fn binary_lambda( let pyseries = if let Ok(expr) = result_series_wrapper.getattr(py, "_pyexpr") { let pyexpr = expr.extract::(py).unwrap(); let expr = pyexpr.inner; - let df = DataFrame::new_no_checks(vec![]); + let df = DataFrame::empty(); let out = df .lazy() .select([expr]) diff --git a/py-polars/src/map/series.rs b/py-polars/src/map/series.rs index b87660c02f7b..afe4172eb788 100644 --- a/py-polars/src/map/series.rs +++ b/py-polars/src/map/series.rs @@ -1,13 +1,10 @@ -use polars::chunked_array::builder::get_list_builder; use polars::prelude::*; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyCFunction, PyDict, PyFloat, PyList, PyString, PyTuple}; +use pyo3::types::{PyBool, PyCFunction, PyFloat, PyList, PyString, PyTuple}; use super::*; use crate::conversion::slice_to_wrapped; use crate::py_modules::SERIES; -use crate::series::PySeries; -use crate::{PyPolarsErr, Wrap}; /// Find the output type and dispatch to that implementation. fn infer_and_finish<'a, A: ApplyLambda<'a>>( @@ -49,7 +46,7 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>( let py_pyseries = series.getattr(py, "_s").unwrap(); let series = py_pyseries.extract::(py).unwrap().series; - // empty dtype is incorrect use anyvalues. + // Empty dtype is incorrect, use AnyValues. if series.is_empty() { let av = out.extract::>()?; return applyer @@ -76,7 +73,7 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>( .map(|ca| ca.into_series().into()); match result { Ok(out) => Ok(out), - // try anyvalue + // Try AnyValue Err(_) => { let av = out.extract::>()?; applyer @@ -126,9 +123,6 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>( pub trait ApplyLambda<'a> { fn apply_lambda_unknown(&'a self, _py: Python, _lambda: &'a PyAny) -> PyResult; - /// Apply a lambda that doesn't change output types - fn apply_lambda(&'a self, _py: Python, _lambda: &'a PyAny) -> PyResult; - // Used to store a struct type fn apply_to_struct( &'a self, @@ -251,11 +245,6 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - self.apply_lambda_with_bool_out_type(py, lambda, 0, None) - .map(|ca| PySeries::new(ca.into_series())) - } - fn apply_to_struct( &'a self, py: Python, @@ -547,11 +536,6 @@ where .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - self.apply_lambda_with_primitive_out_type::(py, lambda, 0, None) - .map(|ca| PySeries::new(ca.into_series())) - } - fn apply_to_struct( &'a self, py: Python, @@ -838,11 +822,6 @@ impl<'a> ApplyLambda<'a> for StringChunked { .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - let ca = self.apply_lambda_with_string_out_type(py, lambda, 0, None)?; - Ok(ca.into_series().into()) - } - fn apply_to_struct( &'a self, py: Python, @@ -1107,40 +1086,6 @@ impl<'a> ApplyLambda<'a> for StringChunked { } } -fn append_series( - pypolars: &PyModule, - builder: &mut (impl ListBuilderTrait + ?Sized), - lambda: &PyAny, - series: Series, -) -> PyResult<()> { - // create a PySeries struct/object for Python - let pyseries = PySeries::new(series); - // Wrap this PySeries object in the python side Series wrapper - let python_series_wrapper = pypolars - .getattr("wrap_s") - .unwrap() - .call1((pyseries,)) - .unwrap(); - // call the lambda en get a python side Series wrapper - let out = lambda.call1((python_series_wrapper,)); - match out { - Ok(out) => { - // unpack the wrapper in a PySeries - let py_pyseries = out - .getattr("_s") - .expect("could not get Series attribute '_s'"); - let pyseries = py_pyseries.extract::()?; - builder - .append_series(&pyseries.series) - .map_err(PyPolarsErr::from)?; - }, - Err(_) => { - builder.append_opt_series(None).map_err(PyPolarsErr::from)?; - }, - }; - Ok(()) -} - fn call_series_lambda(pypolars: &PyModule, lambda: &PyAny, series: Series) -> Option { // create a PySeries struct/object for Python let pyseries = PySeries::new(series); @@ -1196,74 +1141,6 @@ impl<'a> ApplyLambda<'a> for ListChunked { .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - // get the pypolars module - let pypolars = PyModule::import(py, "polars")?; - - match self.dtype() { - DataType::List(dt) => { - let mut builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - if !self.has_validity() { - let mut it = self.into_no_null_iter(); - // use first value to get dtype and replace default builder - if let Some(series) = it.next() { - let out_series = call_series_lambda(pypolars, lambda, series) - .expect("Cannot determine dtype because lambda failed; Make sure that your udf returns a Series"); - let dt = out_series.dtype(); - builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - builder - .append_opt_series(Some(&out_series)) - .map_err(PyPolarsErr::from)?; - } else { - let mut builder = - get_list_builder(dt, 0, 1, self.name()).map_err(PyPolarsErr::from)?; - let ca = builder.finish(); - return Ok(PySeries::new(ca.into_series())); - } - for series in it { - append_series(pypolars, &mut *builder, lambda, series)?; - } - } else { - let mut it = self.into_iter(); - let mut nulls = 0; - - // use first values to get dtype and replace default builders - // continue until no null is found - for opt_series in &mut it { - if let Some(series) = opt_series { - let out_series = call_series_lambda(pypolars, lambda, series) - .expect("Cannot determine dtype because lambda failed; Make sure that your udf returns a Series"); - let dt = out_series.dtype(); - builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - builder - .append_opt_series(Some(&out_series)) - .map_err(PyPolarsErr::from)?; - break; - } else { - nulls += 1; - } - } - for _ in 0..nulls { - builder.append_opt_series(None).map_err(PyPolarsErr::from)?; - } - for opt_series in it { - if let Some(series) = opt_series { - append_series(pypolars, &mut *builder, lambda, series)?; - } else { - builder.append_opt_series(None).unwrap() - } - } - }; - let ca = builder.finish(); - Ok(PySeries::new(ca.into_series())) - }, - _ => unimplemented!(), - } - } - fn apply_to_struct( &'a self, py: Python, @@ -1679,70 +1556,6 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - // get the pypolars module - let pypolars = PyModule::import(py, "polars")?; - - match self.dtype() { - DataType::List(dt) => { - let mut builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - if !self.has_validity() { - let mut it = self.into_no_null_iter(); - // use first value to get dtype and replace default builder - if let Some(series) = it.next() { - let out_series = call_series_lambda(pypolars, lambda, series) - .expect("Cannot determine dtype because lambda failed; Make sure that your udf returns a Series"); - let dt = out_series.dtype(); - builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - builder.append_opt_series(Some(&out_series)); - } else { - let mut builder = - get_list_builder(dt, 0, 1, self.name()).map_err(PyPolarsErr::from)?; - let ca = builder.finish(); - return Ok(PySeries::new(ca.into_series())); - } - for series in it { - append_series(pypolars, &mut *builder, lambda, series)?; - } - } else { - let mut it = self.into_iter(); - let mut nulls = 0; - - // use first values to get dtype and replace default builders - // continue until no null is found - for opt_series in &mut it { - if let Some(series) = opt_series { - let out_series = call_series_lambda(pypolars, lambda, series) - .expect("Cannot determine dtype because lambda failed; Make sure that your udf returns a Series"); - let dt = out_series.dtype(); - builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - builder.append_opt_series(Some(&out_series)); - break; - } else { - nulls += 1; - } - } - for _ in 0..nulls { - builder.append_opt_series(None); - } - for opt_series in it { - if let Some(series) = opt_series { - append_series(pypolars, &mut *builder, lambda, series)?; - } else { - builder.append_opt_series(None) - } - } - }; - let ca = builder.finish(); - Ok(PySeries::new(ca.into_series())) - }, - _ => unimplemented!(), - } - } - fn apply_to_struct( &'a self, py: Python, @@ -2149,18 +1962,6 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - #[cfg(feature = "object")] - { - self.apply_lambda_with_object_out_type(py, lambda, 0, None) - .map(|ca| PySeries::new(ca.into_series())) - } - #[cfg(not(feature = "object"))] - { - todo!() - } - } - fn apply_to_struct( &'a self, _py: Python, @@ -2447,10 +2248,6 @@ impl<'a> ApplyLambda<'a> for StructChunked { Ok(self.clone().into_series().into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - self.apply_lambda_unknown(py, lambda) - } - fn apply_to_struct( &'a self, py: Python, diff --git a/py-polars/src/on_startup.rs b/py-polars/src/on_startup.rs index bf961bb39432..f320d19d4334 100644 --- a/py-polars/src/on_startup.rs +++ b/py-polars/src/on_startup.rs @@ -1,20 +1,17 @@ use std::any::Any; -use std::sync::Arc; use polars::prelude::*; use polars_core::chunked_array::object::builder::ObjectChunkedBuilder; use polars_core::chunked_array::object::registry; use polars_core::chunked_array::object::registry::AnonymousObjectBuilder; use polars_core::error::PolarsError::ComputeError; -use polars_core::error::PolarsResult; -use polars_core::frame::DataFrame; use polars_error::PolarsWarning; use pyo3::intern; use pyo3::prelude::*; use crate::dataframe::PyDataFrame; use crate::map::lazy::{call_lambda_with_series, ToSeries}; -use crate::prelude::{python_udf, ObjectValue}; +use crate::prelude::ObjectValue; use crate::py_modules::{POLARS, UTILS}; use crate::Wrap; @@ -98,7 +95,7 @@ pub fn __register_startup_deps() { unsafe { polars_error::set_warning_function(warning_function) }; Python::with_gil(|py| { // init AnyValue LUT - crate::conversion::anyvalue::LUT + crate::conversion::any_value::LUT .set(py, Default::default()) .unwrap(); }); diff --git a/py-polars/src/series/aggregation.rs b/py-polars/src/series/aggregation.rs index 0b8b6092ed6f..9ed7819d56ac 100644 --- a/py-polars/src/series/aggregation.rs +++ b/py-polars/src/series/aggregation.rs @@ -54,7 +54,7 @@ impl PySeries { .map_err(PyPolarsErr::from)?, ) .into_py(py)), - DataType::Datetime(_, _) => Ok(Wrap( + DataType::Datetime(_, _) | DataType::Duration(_) => Ok(Wrap( self.series .mean_as_series() .get(0) @@ -77,7 +77,7 @@ impl PySeries { .map_err(PyPolarsErr::from)?, ) .into_py(py)), - DataType::Datetime(_, _) => Ok(Wrap( + DataType::Datetime(_, _) | DataType::Duration(_) => Ok(Wrap( self.series .median_as_series() .map_err(PyPolarsErr::from)? diff --git a/py-polars/src/series/c_interface.rs b/py-polars/src/series/c_interface.rs new file mode 100644 index 000000000000..aa87c181cbc3 --- /dev/null +++ b/py-polars/src/series/c_interface.rs @@ -0,0 +1,32 @@ +use polars_rs::export::arrow; +use pyo3::ffi::Py_uintptr_t; + +use super::*; + +// Import arrow data directly without requiring pyarrow (used in pyo3-polars) +#[pymethods] +impl PySeries { + #[staticmethod] + unsafe fn _import_from_c( + name: &str, + chunks: Vec<(Py_uintptr_t, Py_uintptr_t)>, + ) -> PyResult { + let chunks = chunks + .into_iter() + .map(|(schema_ptr, array_ptr)| { + let schema_ptr = schema_ptr as *mut arrow::ffi::ArrowSchema; + let array_ptr = array_ptr as *mut arrow::ffi::ArrowArray; + + // Don't take the box from raw as the other process must deallocate that memory. + let array = std::ptr::read_unaligned(array_ptr); + let schema = &*schema_ptr; + + let field = arrow::ffi::import_field_from_c(schema).unwrap(); + arrow::ffi::import_array_from_c(array, field.data_type).unwrap() + }) + .collect::>(); + + let s = Series::try_from((name, chunks)).map_err(PyPolarsErr::from)?; + Ok(s.into()) + } +} diff --git a/py-polars/src/series/comparison.rs b/py-polars/src/series/comparison.rs index f567fa925a11..c60dbc0e1540 100644 --- a/py-polars/src/series/comparison.rs +++ b/py-polars/src/series/comparison.rs @@ -185,3 +185,59 @@ impl_lt_eq_num!(lt_eq_i64, i64); impl_lt_eq_num!(lt_eq_f32, f32); impl_lt_eq_num!(lt_eq_f64, f64); impl_lt_eq_num!(lt_eq_str, &str); + +struct PyDecimal(i128, usize); + +impl<'source> FromPyObject<'source> for PyDecimal { + fn extract(obj: &'source PyAny) -> PyResult { + if let Ok(val) = obj.extract() { + return Ok(PyDecimal(val, 0)); + } + + let (sign, digits, exponent) = obj + .call_method0("as_tuple")? + .extract::<(i8, Vec, i8)>()?; + let mut val = 0_i128; + for d in digits { + if let Some(v) = val.checked_mul(10).and_then(|val| val.checked_add(d as _)) { + val = v; + } else { + return Err(PyPolarsErr::from(polars_err!(ComputeError: "overflow")).into()); + } + } + let exponent = if exponent > 0 { + if let Some(v) = val.checked_mul(10_i128.pow((-exponent) as u32)) { + val = v; + } else { + return Err(PyPolarsErr::from(polars_err!(ComputeError: "overflow")).into()); + }; + 0_usize + } else { + -exponent as _ + }; + if sign == 1 { + val = -val + }; + Ok(PyDecimal(val, exponent)) + } +} + +macro_rules! impl_decimal { + ($name:ident, $method:ident) => { + #[pymethods] + impl PySeries { + fn $name(&self, rhs: PyDecimal) -> PyResult { + let rhs = Series::new("decimal", &[AnyValue::Decimal(rhs.0, rhs.1)]); + let s = self.series.$method(&rhs).map_err(PyPolarsErr::from)?; + Ok(s.into_series().into()) + } + } + }; +} + +impl_decimal!(eq_decimal, equal); +impl_decimal!(neq_decimal, not_equal); +impl_decimal!(gt_decimal, gt); +impl_decimal!(gt_eq_decimal, gt_eq); +impl_decimal!(lt_decimal, lt); +impl_decimal!(lt_eq_decimal, lt_eq); diff --git a/py-polars/src/series/construction.rs b/py-polars/src/series/construction.rs index 5e159f31a57d..c852be4f7edc 100644 --- a/py-polars/src/series/construction.rs +++ b/py-polars/src/series/construction.rs @@ -184,7 +184,7 @@ init_method_opt!(new_opt_f64, Float64Type, f64); )] impl PySeries { #[staticmethod] - fn new_from_anyvalues( + fn new_from_any_values( name: &str, val: Vec>>, strict: bool, @@ -196,7 +196,7 @@ impl PySeries { } #[staticmethod] - fn new_from_anyvalues_and_dtype( + fn new_from_any_values_and_dtype( name: &str, val: Vec>>, dtype: Wrap, diff --git a/py-polars/src/series/export.rs b/py-polars/src/series/export.rs index e16fe29ff67a..15d3247a863b 100644 --- a/py-polars/src/series/export.rs +++ b/py-polars/src/series/export.rs @@ -1,88 +1,18 @@ +use num_traits::{Float, NumCast}; use numpy::PyArray1; use polars_core::prelude::*; use pyo3::prelude::*; use pyo3::types::PyList; +use crate::conversion::chunked_array::{decimal_to_pyobject_iter, time_to_pyobject_iter}; use crate::error::PyPolarsErr; -use crate::prelude::{ObjectValue, *}; +use crate::prelude::*; use crate::{arrow_interop, raise_err, PySeries}; #[pymethods] impl PySeries { - #[allow(clippy::wrong_self_convention)] - fn to_arrow(&mut self) -> PyResult { - self.rechunk(true); - Python::with_gil(|py| { - let pyarrow = py.import("pyarrow")?; - - arrow_interop::to_py::to_py_array(self.series.to_arrow(0, false), py, pyarrow) - }) - } - - /// For numeric types, this should only be called for Series with null types. - /// Non-nullable types are handled with `view()`. - /// This will cast to floats so that `None = np.nan`. - fn to_numpy(&self, py: Python) -> PyResult { - let s = &self.series; - match s.dtype() { - dt if dt.is_numeric() => { - if s.bit_repr_is_large() { - let s = s.cast(&DataType::Float64).unwrap(); - let ca = s.f64().unwrap(); - let np_arr = PyArray1::from_iter( - py, - ca.into_iter().map(|opt_v| opt_v.unwrap_or(f64::NAN)), - ); - Ok(np_arr.into_py(py)) - } else { - let s = s.cast(&DataType::Float32).unwrap(); - let ca = s.f32().unwrap(); - let np_arr = PyArray1::from_iter( - py, - ca.into_iter().map(|opt_v| opt_v.unwrap_or(f32::NAN)), - ); - Ok(np_arr.into_py(py)) - } - }, - DataType::String => { - let ca = s.str().unwrap(); - let np_arr = PyArray1::from_iter(py, ca.into_iter().map(|s| s.into_py(py))); - Ok(np_arr.into_py(py)) - }, - DataType::Binary => { - let ca = s.binary().unwrap(); - let np_arr = PyArray1::from_iter(py, ca.into_iter().map(|s| s.into_py(py))); - Ok(np_arr.into_py(py)) - }, - DataType::Boolean => { - let ca = s.bool().unwrap(); - let np_arr = PyArray1::from_iter(py, ca.into_iter().map(|s| s.into_py(py))); - Ok(np_arr.into_py(py)) - }, - #[cfg(feature = "object")] - DataType::Object(_, _) => { - let ca = s - .as_any() - .downcast_ref::>() - .unwrap(); - let np_arr = - PyArray1::from_iter(py, ca.into_iter().map(|opt_v| opt_v.to_object(py))); - Ok(np_arr.into_py(py)) - }, - DataType::Null => { - let n = s.len(); - let np_arr = PyArray1::from_iter(py, std::iter::repeat(f32::NAN).take(n)); - Ok(np_arr.into_py(py)) - }, - dt => { - raise_err!( - format!("'to_numpy' not supported for dtype: {dt:?}"), - ComputeError - ); - }, - } - } - + /// Convert this Series to a Python list. + /// This operation copies data. pub fn to_list(&self) -> PyObject { Python::with_gil(|py| { let series = &self.series; @@ -217,4 +147,126 @@ impl PySeries { pylist.to_object(py) }) } + + /// Return the underlying Arrow array. + #[allow(clippy::wrong_self_convention)] + fn to_arrow(&mut self) -> PyResult { + self.rechunk(true); + Python::with_gil(|py| { + let pyarrow = py.import("pyarrow")?; + + arrow_interop::to_py::to_py_array(self.series.to_arrow(0, false), py, pyarrow) + }) + } + + /// Convert this Series to a NumPy ndarray. + /// + /// This method will copy data - numeric types without null values should + /// be handled on the Python side in a zero-copy manner. + /// + /// This method will cast integers to floats so that `null = np.nan`. + fn to_numpy(&self, py: Python) -> PyResult { + use DataType::*; + let s = &self.series; + let out = match s.dtype() { + Int8 => numeric_series_to_numpy::(py, s), + Int16 => numeric_series_to_numpy::(py, s), + Int32 => numeric_series_to_numpy::(py, s), + Int64 => numeric_series_to_numpy::(py, s), + UInt8 => numeric_series_to_numpy::(py, s), + UInt16 => numeric_series_to_numpy::(py, s), + UInt32 => numeric_series_to_numpy::(py, s), + UInt64 => numeric_series_to_numpy::(py, s), + Float32 => numeric_series_to_numpy::(py, s), + Float64 => numeric_series_to_numpy::(py, s), + Boolean => { + let ca = s.bool().unwrap(); + let np_arr = PyArray1::from_iter(py, ca.into_iter().map(|s| s.into_py(py))); + np_arr.into_py(py) + }, + Date => date_series_to_numpy(py, s), + Datetime(_, _) | Duration(_) => temporal_series_to_numpy(py, s), + Time => { + let ca = s.time().unwrap(); + let iter = time_to_pyobject_iter(py, ca); + let np_arr = PyArray1::from_iter(py, iter.map(|v| v.into_py(py))); + np_arr.into_py(py) + }, + String => { + let ca = s.str().unwrap(); + let np_arr = PyArray1::from_iter(py, ca.into_iter().map(|s| s.into_py(py))); + np_arr.into_py(py) + }, + Binary => { + let ca = s.binary().unwrap(); + let np_arr = PyArray1::from_iter(py, ca.into_iter().map(|s| s.into_py(py))); + np_arr.into_py(py) + }, + Categorical(_, _) | Enum(_, _) => { + let ca = s.categorical().unwrap(); + let np_arr = PyArray1::from_iter(py, ca.iter_str().map(|s| s.into_py(py))); + np_arr.into_py(py) + }, + Decimal(_, _) => { + let ca = s.decimal().unwrap(); + let iter = decimal_to_pyobject_iter(py, ca); + let np_arr = PyArray1::from_iter(py, iter.map(|v| v.into_py(py))); + np_arr.into_py(py) + }, + #[cfg(feature = "object")] + Object(_, _) => { + let ca = s + .as_any() + .downcast_ref::>() + .unwrap(); + let np_arr = + PyArray1::from_iter(py, ca.into_iter().map(|opt_v| opt_v.to_object(py))); + np_arr.into_py(py) + }, + Null => { + let n = s.len(); + let np_arr = PyArray1::from_iter(py, std::iter::repeat(f32::NAN).take(n)); + np_arr.into_py(py) + }, + dt => { + raise_err!( + format!("`to_numpy` not supported for dtype {dt:?}"), + ComputeError + ); + }, + }; + Ok(out) + } +} +/// Convert numeric types to f32 or f64 with NaN representing a null value +fn numeric_series_to_numpy(py: Python, s: &Series) -> PyObject +where + T: PolarsNumericType, + U: Float + numpy::Element, +{ + let ca: &ChunkedArray = s.as_ref().as_ref(); + let mapper = |opt_v: Option| match opt_v { + Some(v) => NumCast::from(v).unwrap(), + None => U::nan(), + }; + let np_arr = PyArray1::from_iter(py, ca.iter().map(mapper)); + np_arr.into_py(py) +} +/// Convert dates directly to i64 with i64::MIN representing a null value +fn date_series_to_numpy(py: Python, s: &Series) -> PyObject { + let s_phys = s.to_physical_repr(); + let ca = s_phys.i32().unwrap(); + let mapper = |opt_v: Option| match opt_v { + Some(v) => v as i64, + None => i64::MIN, + }; + let np_arr = PyArray1::from_iter(py, ca.iter().map(mapper)); + np_arr.into_py(py) +} +/// Convert datetimes and durations with i64::MIN representing a null value +fn temporal_series_to_numpy(py: Python, s: &Series) -> PyObject { + let s_phys = s.to_physical_repr(); + let ca = s_phys.i64().unwrap(); + let np_arr = PyArray1::from_iter(py, ca.iter().map(|v| v.unwrap_or(i64::MIN))); + np_arr.into_py(py) } diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index 5e54945836aa..a523ee6439ea 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -1,6 +1,7 @@ mod aggregation; mod arithmetic; mod buffers; +mod c_interface; mod comparison; mod construction; mod export; @@ -49,7 +50,7 @@ pub(crate) trait ToSeries { impl ToSeries for Vec { fn to_series(self) -> Vec { - // Safety + // SAFETY: // repr is transparent unsafe { std::mem::transmute(self) } } @@ -61,7 +62,7 @@ pub(crate) trait ToPySeries { impl ToPySeries for Vec { fn to_pyseries(self) -> Vec { - // Safety + // SAFETY: // repr is transparent unsafe { std::mem::transmute(self) } } @@ -388,7 +389,8 @@ impl PySeries { ) || !skip_nulls { let mut avs = Vec::with_capacity(self.series.len()); - let iter = self.series.iter().map(|av| match (skip_nulls, av) { + let s = self.series.rechunk(); + let iter = s.iter().map(|av| match (skip_nulls, av) { (true, AnyValue::Null) => AnyValue::Null, (_, av) => { let input = Wrap(av); @@ -703,6 +705,11 @@ impl PySeries { let length = length.unwrap_or_else(|| self.series.len()); self.series.slice(offset, length).into() } + + pub fn not_(&self) -> PyResult { + let out = polars_ops::series::negate_bitwise(&self.series).map_err(PyPolarsErr::from)?; + Ok(out.into()) + } } macro_rules! impl_set_with_mask { diff --git a/py-polars/src/to_numpy.rs b/py-polars/src/to_numpy.rs new file mode 100644 index 000000000000..53bb9ff014fa --- /dev/null +++ b/py-polars/src/to_numpy.rs @@ -0,0 +1,189 @@ +use std::ffi::{c_int, c_void}; + +use ndarray::{Dim, Dimension, IntoDimension}; +use numpy::npyffi::{flags, PyArrayObject}; +use numpy::{npyffi, Element, IntoPyArray, ToNpyDims, PY_ARRAY_API}; +use polars_core::prelude::*; +use polars_core::utils::try_get_supertype; +use polars_core::with_match_physical_numeric_polars_type; +use pyo3::prelude::*; + +use crate::conversion::Wrap; +use crate::dataframe::PyDataFrame; +use crate::series::PySeries; + +pub(crate) unsafe fn create_borrowed_np_array( + py: Python, + mut shape: Dim, + flags: c_int, + data: *mut c_void, + owner: PyObject, +) -> PyObject +where + Dim: Dimension + ToNpyDims, +{ + // See: https://numpy.org/doc/stable/reference/c-api/array.html + let array = PY_ARRAY_API.PyArray_NewFromDescr( + py, + PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type), + T::get_dtype(py).into_dtype_ptr(), + shape.ndim_cint(), + shape.as_dims_ptr(), + // We don't provide strides, but provide flags that tell c/f-order + std::ptr::null_mut(), + data, + flags, + std::ptr::null_mut(), + ); + + // This keeps the memory alive + let owner_ptr = owner.as_ptr(); + // SetBaseObject steals a reference + // so we can forget. + std::mem::forget(owner); + PY_ARRAY_API.PyArray_SetBaseObject(py, array as *mut PyArrayObject, owner_ptr); + + let any: &PyAny = py.from_owned_ptr(array); + any.into_py(py) +} + +#[pymethods] +impl PySeries { + /// Create a view of the data as a NumPy ndarray. + /// + /// WARNING: The resulting view will show the underlying value for nulls, + /// which may be any value. The caller is responsible for handling nulls + /// appropriately. + #[allow(clippy::wrong_self_convention)] + pub fn to_numpy_view(&self, py: Python) -> Option { + // NumPy arrays are always contiguous + if self.series.n_chunks() > 1 { + return None; + } + + match self.series.dtype() { + dt if dt.is_numeric() => { + let dims = [self.series.len()].into_dimension(); + // Object to the series keep the memory alive. + let owner = self.clone().into_py(py); + with_match_physical_numeric_polars_type!(self.series.dtype(), |$T| { + let ca: &ChunkedArray<$T> = self.series.unpack::<$T>().unwrap(); + let slice = ca.data_views().next().unwrap(); + let view = unsafe { + create_borrowed_np_array::<<$T as PolarsNumericType>::Native, _>( + py, + dims, + flags::NPY_ARRAY_FARRAY_RO, + slice.as_ptr() as _, + owner, + ) + }; + Some(view) + }) + }, + _ => None, + } + } +} + +#[pymethods] +#[allow(clippy::wrong_self_convention)] +impl PyDataFrame { + pub fn to_numpy_view(&self, py: Python) -> Option { + if self.df.is_empty() { + return None; + } + let first = self.df.get_columns().first().unwrap().dtype(); + if !first.is_numeric() { + return None; + } + if !self + .df + .get_columns() + .iter() + .all(|s| s.null_count() == 0 && s.dtype() == first && s.chunks().len() == 1) + { + return None; + } + + // Object to the dataframe keep the memory alive. + let owner = self.clone().into_py(py); + + fn get_ptr( + py: Python, + columns: &[Series], + owner: PyObject, + ) -> Option + where + T::Native: Element, + { + let slices = columns + .iter() + .map(|s| { + let ca: &ChunkedArray = s.unpack().unwrap(); + ca.cont_slice().unwrap() + }) + .collect::>(); + + let first = slices.first().unwrap(); + unsafe { + let mut end_ptr = first.as_ptr().add(first.len()); + // Check if all arrays are from the same buffer + let all_contiguous = slices[1..].iter().all(|slice| { + let valid = slice.as_ptr() == end_ptr; + + end_ptr = slice.as_ptr().add(slice.len()); + + valid + }); + + if all_contiguous { + let start_ptr = first.as_ptr(); + let dims = [first.len(), columns.len()].into_dimension(); + Some(create_borrowed_np_array::( + py, + dims, + flags::NPY_ARRAY_FARRAY_RO, + start_ptr as _, + owner, + )) + } else { + None + } + } + } + with_match_physical_numeric_polars_type!(first, |$T| { + get_ptr::<$T>(py, self.df.get_columns(), owner) + }) + } + + pub fn to_numpy(&self, py: Python, order: Wrap) -> Option { + let mut st = None; + for s in self.df.iter() { + let dt_i = s.dtype(); + match st { + None => st = Some(dt_i.clone()), + Some(ref mut st) => { + *st = try_get_supertype(st, dt_i).ok()?; + }, + } + } + let st = st?; + + #[rustfmt::skip] + let pyarray = match st { + DataType::UInt8 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), + DataType::Int8 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), + DataType::UInt16 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), + DataType::Int16 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), + DataType::UInt32 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), + DataType::UInt64 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), + DataType::Int32 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), + DataType::Int64 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), + DataType::Float32 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), + DataType::Float64 => self.df.to_ndarray::(order.0).ok()?.into_pyarray(py).into_py(py), + _ => return None, + }; + Some(pyarray) + } +} diff --git a/py-polars/tests/docs/run_doctest.py b/py-polars/tests/docs/run_doctest.py index 65282ea70cfa..b4fffcca14fe 100644 --- a/py-polars/tests/docs/run_doctest.py +++ b/py-polars/tests/docs/run_doctest.py @@ -43,6 +43,16 @@ from types import ModuleType +if sys.version_info < (3, 12): + # Tests that print an OrderedDict fail (e.g. DataFrame.schema) as the repr + # has changed in Python 3.12 + warnings.warn( + "Certain doctests may fail when running on a Python version below 3.12." + " Update your Python version to 3.12 or later to make sure all tests pass.", + stacklevel=2, + ) + + def doctest_teardown(d: doctest.DocTest) -> None: # don't let config changes or string cache state leak between tests polars.Config.restore_defaults() diff --git a/py-polars/tests/docs/test_user_guide.py b/py-polars/tests/docs/test_user_guide.py index 032961dd936a..3b17f7196c77 100644 --- a/py-polars/tests/docs/test_user_guide.py +++ b/py-polars/tests/docs/test_user_guide.py @@ -15,11 +15,14 @@ python_snippets_dir = repo_root / "docs" / "src" / "python" snippet_paths = list(python_snippets_dir.rglob("*.py")) +# Skip visualization snippets +snippet_paths = [p for p in snippet_paths if "visualization" not in str(p)] + @pytest.fixture(scope="module") def _change_test_dir() -> Iterator[None]: """Change path to repo root to accommodate data paths in code snippets.""" - current_path = Path() + current_path = Path().resolve() os.chdir(repo_root) yield os.chdir(current_path) diff --git a/py-polars/tests/parametric/test_series.py b/py-polars/tests/parametric/test_series.py index 27d4062afe76..4dedafaf888b 100644 --- a/py-polars/tests/parametric/test_series.py +++ b/py-polars/tests/parametric/test_series.py @@ -3,95 +3,13 @@ # ------------------------------------------------- from __future__ import annotations -from typing import Any - from hypothesis import given, settings -from hypothesis.strategies import booleans, floats, sampled_from +from hypothesis.strategies import sampled_from import polars as pl -from polars.expr.expr import _prepare_alpha from polars.testing import assert_series_equal from polars.testing.parametric import series - -def alpha_guard(**decay_param: float) -> bool: - """Protects against unnecessary noise in small number regime.""" - if not next(iter(decay_param.values())): - return True - alpha = _prepare_alpha(**decay_param) - return ((1 - alpha) if round(alpha) else alpha) > 1e-6 - - -@given( - s=series( - min_size=4, - dtype=pl.Float64, - null_probability=0.05, - strategy=floats(min_value=-1e8, max_value=1e8), - ), - half_life=floats(min_value=0, max_value=4, exclude_min=True).filter( - lambda x: alpha_guard(half_life=x) - ), - com=floats(min_value=0, max_value=99).filter(lambda x: alpha_guard(com=x)), - span=floats(min_value=1, max_value=10).filter(lambda x: alpha_guard(span=x)), - ignore_nulls=booleans(), - adjust=booleans(), - bias=booleans(), -) -def test_ewm_methods( - s: pl.Series, - com: float | None, - span: float | None, - half_life: float | None, - ignore_nulls: bool, - adjust: bool, - bias: bool, -) -> None: - # validate a large set of varied EWM calculations - for decay_param in [{"com": com}, {"span": span}, {"half_life": half_life}]: - alpha = _prepare_alpha(**decay_param) - - # convert parametrically-generated series to pandas, then use that as a - # reference implementation for comparison (after normalising NaN/None) - p = s.to_pandas() - - # note: skip min_periods < 2, due to pandas-side inconsistency: - # https://github.com/pola-rs/polars/issues/5006#issuecomment-1259477178 - for mp in range(2, len(s), len(s) // 3): - # consolidate ewm parameters - pl_params: dict[str, Any] = { - "min_periods": mp, - "adjust": adjust, - "ignore_nulls": ignore_nulls, - } - pl_params.update(decay_param) - pd_params = pl_params.copy() - if "half_life" in pl_params: - pd_params["halflife"] = pd_params.pop("half_life") - if "ignore_nulls" in pl_params: - pd_params["ignore_na"] = pd_params.pop("ignore_nulls") - - # mean: - ewm_mean_pl = s.ewm_mean(**pl_params).fill_nan(None) - ewm_mean_pd = pl.Series(p.ewm(**pd_params).mean()) - if alpha == 1: - # apply fill-forward to nulls to match pandas - # https://github.com/pola-rs/polars/pull/5011#issuecomment-1262318124 - ewm_mean_pl = ewm_mean_pl.fill_null(strategy="forward") - - assert_series_equal(ewm_mean_pl, ewm_mean_pd, atol=1e-07) - - # std: - ewm_std_pl = s.ewm_std(bias=bias, **pl_params).fill_nan(None) - ewm_std_pd = pl.Series(p.ewm(**pd_params).std(bias=bias)) - assert_series_equal(ewm_std_pl, ewm_std_pd, atol=1e-07) - - # var: - ewm_var_pl = s.ewm_var(bias=bias, **pl_params).fill_nan(None) - ewm_var_pd = pl.Series(p.ewm(**pd_params).var(bias=bias)) - assert_series_equal(ewm_var_pl, ewm_var_pd, atol=1e-07) - - # TODO: once Decimal is a little further along, start actively probing it # @given( # s=series(max_size=10, dtype=pl.Decimal, null_probability=0.1), diff --git a/py-polars/tests/parametric/test_testing.py b/py-polars/tests/parametric/test_testing.py index 6b47c43f35c4..a55b42d9ede5 100644 --- a/py-polars/tests/parametric/test_testing.py +++ b/py-polars/tests/parametric/test_testing.py @@ -210,7 +210,7 @@ def finite_float(value: Any) -> bool: @given( df=dataframes( cols=[ - column("colx", dtype=pl.List(pl.UInt8)), + column("colx", dtype=pl.Array(pl.UInt8, width=3)), column("coly", dtype=pl.List(pl.Datetime("ms"))), column( name="colz", @@ -223,15 +223,16 @@ def finite_float(value: Any) -> bool: ] ), ) -def test_list_strategy(df: pl.DataFrame) -> None: +def test_sequence_strategies(df: pl.DataFrame) -> None: assert df.schema == { - "colx": pl.List(pl.UInt8), + "colx": pl.Array(pl.UInt8, width=3), "coly": pl.List(pl.Datetime("ms")), "colz": pl.List(pl.List(pl.String)), } uint8_max = (2**8) - 1 for colx, coly, colz in df.iter_rows(): + assert len(colx) == 3 assert all(i <= uint8_max for i in colx) assert all(isinstance(d, datetime) for d in coly) for inner_list in colz: diff --git a/py-polars/tests/parametric/time_series/test_to_datetime.py b/py-polars/tests/parametric/time_series/test_to_datetime.py index 6e097bec5477..65785c60fe86 100644 --- a/py-polars/tests/parametric/time_series/test_to_datetime.py +++ b/py-polars/tests/parametric/time_series/test_to_datetime.py @@ -6,11 +6,13 @@ import polars as pl from polars.exceptions import ComputeError from polars.testing.parametric.strategies import strategy_datetime_format +from polars.type_aliases import TimeUnit @given( datetimes=st.datetimes( - min_value=datetime(2000, 1, 1), max_value=datetime(9999, 12, 31) + min_value=datetime(1699, 1, 1), + max_value=datetime(9999, 12, 31), ), fmt=strategy_datetime_format(), ) @@ -42,3 +44,27 @@ def test_to_datetime(datetimes: datetime, fmt: str) -> None: ) else: assert result == expected + + +@given( + d=st.datetimes( + min_value=datetime(1699, 1, 1), + max_value=datetime(9999, 12, 31), + ), + tu=st.sampled_from(["ms", "us"]), +) +def test_cast_to_time_and_combine(d: datetime, tu: TimeUnit) -> None: + # round-trip date/time extraction + recombining + df = pl.DataFrame({"d": [d]}, schema={"d": pl.Datetime(tu)}) + res = df.select( + d=pl.col("d"), + dt=pl.col("d").dt.date(), + tm=pl.col("d").cast(pl.Time), + ).with_columns( + dtm=pl.col("dt").dt.combine(pl.col("tm")), + ) + + datetimes = res["d"].to_list() + assert [d.date() for d in datetimes] == res["dt"].to_list() + assert [d.time() for d in datetimes] == res["tm"].to_list() + assert datetimes == res["dtm"].to_list() diff --git a/py-polars/tests/parametric/time_series/test_truncate.py b/py-polars/tests/parametric/time_series/test_truncate.py index 80dce97fb9a0..6e684ce130ad 100644 --- a/py-polars/tests/parametric/time_series/test_truncate.py +++ b/py-polars/tests/parametric/time_series/test_truncate.py @@ -8,7 +8,8 @@ @given( value=st.datetimes( - min_value=dt.datetime(1000, 1, 1), max_value=dt.datetime(3000, 1, 1) + min_value=dt.datetime(1000, 1, 1), + max_value=dt.datetime(3000, 1, 1), ), n=st.integers(min_value=1, max_value=100), ) diff --git a/py-polars/tests/unit/constructors/__init__.py b/py-polars/tests/unit/constructors/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py new file mode 100644 index 000000000000..bff14e5a461e --- /dev/null +++ b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py @@ -0,0 +1,73 @@ +# TODO: Replace direct calls to fallback constructors with calls to the Series +# constructor once the Python-side logic has been updated +from __future__ import annotations + +from datetime import date +from typing import Any + +import pytest + +import polars as pl +from polars.polars import PySeries +from polars.utils._wrap import wrap_s + + +@pytest.mark.parametrize( + ("dtype", "values"), + [ + (pl.Boolean, [True, False, None]), + (pl.Binary, [b"123", b"xyz", None]), + (pl.String, ["123", "xyz", None]), + ], +) +def test_fallback_with_dtype_strict( + dtype: pl.PolarsDataType, values: list[Any] +) -> None: + result = wrap_s( + PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True) + ) + assert result.to_list() == values + + +@pytest.mark.parametrize( + ("dtype", "values"), + [ + (pl.Boolean, [0, 1]), + (pl.Binary, ["123", "xyz"]), + (pl.String, [b"123", b"xyz"]), + ], +) +def test_fallback_with_dtype_strict_failure( + dtype: pl.PolarsDataType, values: list[Any] +) -> None: + with pytest.raises(pl.SchemaError, match="unexpected value"): + PySeries.new_from_any_values_and_dtype("", values, pl.Boolean, strict=True) + + +@pytest.mark.parametrize( + ("dtype", "values", "expected"), + [ + ( + pl.Boolean, + [False, True, 0, 1, 0.0, 2.5, date(1970, 1, 1)], + [False, True, False, True, False, True, None], + ), + ( + pl.Binary, + [b"123", "xyz", 100, True, None], + [b"123", b"xyz", None, None, None], + ), + ( + pl.String, + ["xyz", 1, 2.5, date(1970, 1, 1), True, b"123", None], + ["xyz", "1", "2.5", "1970-01-01", "true", None, None], + ), + ], +) +def test_fallback_with_dtype_nonstrict( + dtype: pl.PolarsDataType, values: list[Any], expected: list[Any] +) -> None: + result = wrap_s( + PySeries.new_from_any_values_and_dtype("", values, dtype, strict=False) + ) + assert result.to_list() == expected diff --git a/py-polars/tests/unit/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py similarity index 94% rename from py-polars/tests/unit/test_constructors.py rename to py-polars/tests/unit/constructors/test_constructors.py index 01a4de503444..5e836cb0775b 100644 --- a/py-polars/tests/unit/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -2,7 +2,7 @@ import sys from collections import OrderedDict, namedtuple -from datetime import date, datetime, timedelta, timezone +from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal from random import shuffle from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple @@ -14,12 +14,15 @@ from pydantic import BaseModel, Field, TypeAdapter import polars as pl +from polars.datatypes import PolarsDataType, numpy_char_code_to_dtype from polars.dependencies import _ZONEINFO_AVAILABLE, dataclasses, pydantic from polars.exceptions import TimeZoneAwareConstructorWarning from polars.testing import assert_frame_equal, assert_series_equal from polars.utils._construction import type_hints if TYPE_CHECKING: + from collections.abc import Callable + from polars.datatypes import PolarsDataType if sys.version_info >= (3, 9): @@ -797,6 +800,45 @@ def test_init_series() -> None: assert_series_equal(s5, pl.Series("", [1, 2, 3], dtype=pl.Int8)) +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (int, pl.Int64), + (bytes, pl.Binary), + (float, pl.Float64), + (str, pl.String), + (date, pl.Date), + (time, pl.Time), + (datetime, pl.Datetime("us")), + (timedelta, pl.Duration("us")), + (Decimal, pl.Decimal(precision=None, scale=0)), + ], +) +def test_init_py_dtype(dtype: Any, expected_dtype: PolarsDataType) -> None: + for s in ( + pl.Series("s", [None], dtype=dtype), + pl.Series("s", [], dtype=dtype), + ): + assert s.dtype == expected_dtype + + for df in ( + pl.DataFrame({"col": [None]}, schema={"col": dtype}), + pl.DataFrame({"col": []}, schema={"col": dtype}), + ): + assert df.schema == {"col": expected_dtype} + + +def test_init_py_dtype_misc_float() -> None: + assert pl.Series([100], dtype=float).dtype == pl.Float64 # type: ignore[arg-type] + + df = pl.DataFrame( + {"x": [100.0], "y": [200], "z": [None]}, + schema={"x": float, "y": float, "z": float}, + ) + assert df.schema == {"x": pl.Float64, "y": pl.Float64, "z": pl.Float64} + assert df.rows() == [(100.0, 200.0, None)] + + def test_init_seq_of_seq() -> None: # List of lists df = pl.DataFrame([[1, 2, 3], [4, 5, 6]], schema=["a", "b", "c"]) @@ -1161,6 +1203,11 @@ def test_from_rows_dtype() -> None: assert df.dtypes == [pl.Int32, pl.Object, pl.Object] assert df.null_count().row(0) == (0, 0, 0) + dc = _TestBazDC(d=datetime(2020, 2, 22), e=42.0, f="xyz") + df = pl.DataFrame([[dc]], schema={"d": pl.Object}) + assert df.schema == {"d": pl.Object} + assert df.item() == dc + def test_from_dicts_schema() -> None: data = [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}] @@ -1513,3 +1560,44 @@ def test_df_init_dict_raise_on_expression_input() -> None: # Passing a list of expressions is allowed df = pl.DataFrame({"a": [pl.int_range(0, 3)]}) assert df.get_column("a").dtype == pl.Object + + +def test_df_schema_sequences() -> None: + schema = [ + ["address", pl.String], + ["key", pl.Int64], + ["value", pl.Float32], + ] + df = pl.DataFrame(schema=schema) # type: ignore[arg-type] + assert df.schema == {"address": pl.String, "key": pl.Int64, "value": pl.Float32} + + +def test_df_schema_sequences_incorrect_length() -> None: + schema = [ + ["address", pl.String, pl.Int8], + ["key", pl.Int64], + ["value", pl.Float32], + ] + with pytest.raises(ValueError): + pl.DataFrame(schema=schema) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("input", "infer_func", "expected_dtype"), + [ + ("f8", numpy_char_code_to_dtype, pl.Float64), + ("f4", numpy_char_code_to_dtype, pl.Float32), + ("i4", numpy_char_code_to_dtype, pl.Int32), + ("u1", numpy_char_code_to_dtype, pl.UInt8), + ("?", numpy_char_code_to_dtype, pl.Boolean), + ("m8", numpy_char_code_to_dtype, pl.Duration("us")), + ("M8", numpy_char_code_to_dtype, pl.Datetime("us")), + ], +) +def test_numpy_inference( + input: Any, + infer_func: Callable[[Any], PolarsDataType], + expected_dtype: PolarsDataType, +) -> None: + result = infer_func(input) + assert result == expected_dtype diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 702088ea4161..339c2a6ab038 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -13,7 +13,6 @@ import numpy as np import pyarrow as pa import pytest -from numpy.testing import assert_array_equal, assert_equal import polars as pl import polars.selectors as cs @@ -28,7 +27,7 @@ from polars.utils._construction import iterable_to_pydf if TYPE_CHECKING: - from polars.type_aliases import IndexOrder, JoinStrategy, UniqueKeepStrategy + from polars.type_aliases import JoinStrategy, UniqueKeepStrategy if sys.version_info >= (3, 9): from zoneinfo import ZoneInfo @@ -96,6 +95,21 @@ def test_comparisons() -> None: assert_frame_equal( df == other, pl.DataFrame({"a": [True, True], "b": [False, False]}) ) + assert_frame_equal( + df != other, pl.DataFrame({"a": [False, False], "b": [True, True]}) + ) + assert_frame_equal( + df > other, pl.DataFrame({"a": [False, False], "b": [True, True]}) + ) + assert_frame_equal( + df < other, pl.DataFrame({"a": [False, False], "b": [False, False]}) + ) + assert_frame_equal( + df >= other, pl.DataFrame({"a": [True, True], "b": [True, True]}) + ) + assert_frame_equal( + df <= other, pl.DataFrame({"a": [True, True], "b": [False, False]}) + ) # DataFrame columns mismatch with pytest.raises(ValueError): @@ -244,6 +258,15 @@ def test_from_arrow(monkeypatch: Any) -> None: assert df.schema == expected_schema assert df.rows() == expected_data + # record batches (inc. empty) + for b, n_expected in ( + (record_batches[0], 1), + (record_batches[0][:0], 0), + ): + df = cast(pl.DataFrame, pl.from_arrow(b)) + assert df.schema == expected_schema + assert df.rows() == expected_data[:n_expected] + empty_tbl = tbl[:0] # no rows df = cast(pl.DataFrame, pl.from_arrow(empty_tbl)) assert df.schema == expected_schema @@ -381,6 +404,9 @@ def test_to_series() -> None: assert_series_equal(df.to_series(2), df["z"]) assert_series_equal(df.to_series(-1), df["z"]) + with pytest.raises(TypeError, match="should be an int"): + df.to_series("x") # type: ignore[arg-type] + def test_gather_every() -> None: df = pl.DataFrame({"a": [1, 2, 3, 4], "b": ["w", "x", "y", "z"]}) @@ -971,97 +997,6 @@ def test_assign() -> None: assert list(df["a"]) == [2, 4, 6] -@pytest.mark.parametrize( - ("order", "f_contiguous", "c_contiguous"), - [("fortran", True, False), ("c", False, True)], -) -def test_to_numpy(order: IndexOrder, f_contiguous: bool, c_contiguous: bool) -> None: - df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) - - out_array = df.to_numpy(order=order) - expected_array = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=np.float64) - assert_array_equal(out_array, expected_array) - assert out_array.flags["F_CONTIGUOUS"] == f_contiguous - assert out_array.flags["C_CONTIGUOUS"] == c_contiguous - - structured_array = df.to_numpy(structured=True, order=order) - expected_array = np.array( - [(1, 1.0), (2, 2.0), (3, 3.0)], dtype=[("a", " None: - # round-trip structured array: validate init/export - structured_array = np.array( - [ - ("Google Pixel 7", 521.90, True), - ("Apple iPhone 14 Pro", 999.00, True), - ("OnePlus 11", 699.00, True), - ("Samsung Galaxy S23 Ultra", 1199.99, False), - ], - dtype=np.dtype( - [ - ("product", "U24"), - ("price_usd", "float64"), - ("in_stock", "bool"), - ] - ), - ) - df = pl.from_numpy(structured_array) - assert df.schema == { - "product": pl.String, - "price_usd": pl.Float64, - "in_stock": pl.Boolean, - } - exported_array = df.to_numpy(structured=True) - assert exported_array["product"].dtype == np.dtype("U24") - assert_array_equal(exported_array, structured_array) - - # none/nan values - df = pl.DataFrame({"x": ["a", None, "b"], "y": [5.5, None, -5.5]}) - exported_array = df.to_numpy(structured=True) - - assert exported_array.dtype == np.dtype([("x", object), ("y", float)]) - for name in df.columns: - assert_equal( - list(exported_array[name]), - ( - df[name].fill_null(float("nan")) - if df.schema[name].is_float() - else df[name] - ).to_list(), - ) - - -def test__array__() -> None: - df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) - - out_array = np.asarray(df.to_numpy()) - expected_array = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=np.float64) - assert_array_equal(out_array, expected_array) - assert out_array.flags["F_CONTIGUOUS"] is True - - out_array = np.asarray(df.to_numpy(), np.uint8) - expected_array = np.array([[1, 1], [2, 2], [3, 3]], dtype=np.uint8) - assert_array_equal(out_array, expected_array) - assert out_array.flags["F_CONTIGUOUS"] is True - - def test_arg_sort_by(df: pl.DataFrame) -> None: idx_df = df.select( pl.arg_sort_by(["int_nulls", "floats"], descending=[False, True]).alias("idx") @@ -1123,32 +1058,6 @@ def test_literal_series() -> None: ) -def test_to_html() -> None: - # check it does not panic/error, and appears to contain - # a reasonable table with suitably escaped html entities. - df = pl.DataFrame( - { - "foo": [1, 2, 3], - "": ["a", "b", "c"], - "": ["a", "b", "c"], - } - ) - html = df._repr_html_() - for match in ( - "foo", - "<bar>", - "<baz", - "spam>", - "1", - "2", - "3", - ): - assert match in html, f"Expected to find {match!r} in html repr" - - def test_rename(df: pl.DataFrame) -> None: out = df.rename({"strings": "bars", "int": "foos"}) # check if we can select these new columns @@ -2226,6 +2135,12 @@ def test_getitem() -> None: with pytest.raises(TypeError): _ = df[np.array([1.0])] + with pytest.raises( + TypeError, + match="multi-dimensional NumPy arrays not supported", + ): + df[np.array([[0], [1]])] + # sequences (lists or tuples; tuple only if length != 2) # if strings or list of expressions, assumed to be column names # if bools, assumed to be a row mask @@ -2276,6 +2191,13 @@ def test_getitem() -> None: with pytest.raises(TypeError): df[pl.Series([True, False, True]), "b"] + # wrong length boolean mask for column selection + with pytest.raises( + ValueError, + match=f"expected {df.width} values when selecting columns by boolean mask", + ): + df[:, [True, False, True]] + # 5343 df = pl.DataFrame( { @@ -2324,6 +2246,7 @@ def test_product() -> None: "flt": [-1.0, 12.0, 9.0], "bool_0": [True, False, True], "bool_1": [True, True, True], + "str": ["a", "b", "c"], }, schema_overrides={ "int": pl.UInt16, @@ -2331,7 +2254,9 @@ def test_product() -> None: }, ) out = df.product() - expected = pl.DataFrame({"int": [6], "flt": [-108.0], "bool_0": [0], "bool_1": [1]}) + expected = pl.DataFrame( + {"int": [6], "flt": [-108.0], "bool_0": [0], "bool_1": [1], "str": [None]} + ) assert_frame_not_equal(out, expected, check_dtype=True) assert_frame_equal(out, expected, check_dtype=False) @@ -2868,7 +2793,7 @@ def test_init_datetimes_with_timezone() -> None: tz_europe = "Europe/Amsterdam" dtm = datetime(2022, 10, 12, 12, 30) - for time_unit in DTYPE_TEMPORAL_UNITS | frozenset([None]): + for time_unit in DTYPE_TEMPORAL_UNITS: for type_overrides in ( { "schema": [ @@ -2970,7 +2895,7 @@ def test_init_physical_with_timezone() -> None: tz_asia = "Asia/Tokyo" dtm_us = 1665577800000000 - for time_unit in DTYPE_TEMPORAL_UNITS | frozenset([None]): + for time_unit in DTYPE_TEMPORAL_UNITS: dtm = {"ms": dtm_us // 1_000, "ns": dtm_us * 1_000}.get(str(time_unit), dtm_us) df = pl.DataFrame( data={"d1": [dtm], "d2": [dtm]}, @@ -3117,81 +3042,6 @@ def test_dot() -> None: assert df.select(pl.col("a").dot(pl.col("b"))).item() == 12.96 -def test_ufunc() -> None: - df = pl.DataFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)]) - out = df.select( - [ - np.power(pl.col("a"), 2).alias("power_uint8"), # type: ignore[call-overload] - np.power(pl.col("a"), 2.0).alias("power_float64"), # type: ignore[call-overload] - np.power(pl.col("a"), 2, dtype=np.uint16).alias("power_uint16"), # type: ignore[call-overload] - ] - ) - expected = pl.DataFrame( - [ - pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8), - pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), - pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16), - ] - ) - assert_frame_equal(out, expected) - assert out.dtypes == expected.dtypes - - -def test_ufunc_expr_not_first() -> None: - """Check numpy ufunc expressions also work if expression not the first argument.""" - df = pl.DataFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)]) - out = df.select( - [ - np.power(2.0, cast(Any, pl.col("a"))).alias("power"), - (2.0 / cast(Any, pl.col("a"))).alias("divide_scalar"), - (np.array([2, 2, 2]) / cast(Any, pl.col("a"))).alias("divide_array"), - ] - ) - expected = pl.DataFrame( - [ - pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64), - pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), - pl.Series("divide_array", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), - ] - ) - assert_frame_equal(out, expected) - - -def test_ufunc_multiple_expressions() -> None: - # example from https://github.com/pola-rs/polars/issues/6770 - df = pl.DataFrame( - { - "v": [ - -4.293, - -2.4659, - -1.8378, - -0.2821, - -4.5649, - -3.8128, - -7.4274, - 3.3443, - 3.8604, - -4.2200, - ], - "u": [ - -11.2268, - 6.3478, - 7.1681, - 3.4986, - 2.7320, - -1.0695, - -10.1408, - 11.2327, - 6.6623, - -8.1412, - ], - } - ) - expected = np.arctan2(df.get_column("v"), df.get_column("u")) - result = df.select(np.arctan2(pl.col("v"), pl.col("u")))[:, 0] # type: ignore[call-overload] - assert_series_equal(expected, result) # type: ignore[arg-type] - - def test_unstack() -> None: from string import ascii_uppercase diff --git a/py-polars/tests/unit/dataframe/test_glimpse.py b/py-polars/tests/unit/dataframe/test_glimpse.py index 67a98bf46ac6..022bf7205d76 100644 --- a/py-polars/tests/unit/dataframe/test_glimpse.py +++ b/py-polars/tests/unit/dataframe/test_glimpse.py @@ -61,7 +61,7 @@ def test_glimpse(capsys: Any) -> None: assert result == expected # the default is to print to the console - df.glimpse(return_as_string=False) + df.glimpse() # remove the last newline on the capsys assert capsys.readouterr().out[:-1] == expected diff --git a/py-polars/tests/unit/dataframe/test_repr_html.py b/py-polars/tests/unit/dataframe/test_repr_html.py new file mode 100644 index 000000000000..8e7a62a6efc2 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_repr_html.py @@ -0,0 +1,79 @@ +import polars as pl + + +def test_repr_html() -> None: + # check it does not panic/error, and appears to contain + # a reasonable table with suitably escaped html entities. + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "": ["a", "b", "c"], + "": ["a", "b", "c"], + } + ) + html = df._repr_html_() + for match in ( + "foo", + "<bar>", + "<baz", + "spam>", + "1", + "2", + "3", + ): + assert match in html, f"Expected to find {match!r} in html repr" + + +def test_html_tables() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + + # default: header contains names/dtypes + header = "abci64i64i64" + assert header in df._repr_html_() + + # validate that relevant config options are respected + with pl.Config(tbl_hide_column_names=True): + header = "i64i64i64" + assert header in df._repr_html_() + + with pl.Config(tbl_hide_column_data_types=True): + header = "abc" + assert header in df._repr_html_() + + with pl.Config( + tbl_hide_column_data_types=True, + tbl_hide_column_names=True, + ): + header = "" + assert header in df._repr_html_() + + +def test_df_repr_html_max_rows_default() -> None: + df = pl.DataFrame({"a": range(50)}) + + html = df._repr_html_() + + expected_rows = 10 + assert html.count("") - 2 == expected_rows + + +def test_df_repr_html_max_rows_odd() -> None: + df = pl.DataFrame({"a": range(50)}) + + with pl.Config(tbl_rows=9): + html = df._repr_html_() + + expected_rows = 9 + assert html.count("") - 2 == expected_rows + + +def test_series_repr_html_max_rows_default() -> None: + s = pl.Series("a", range(50)) + + html = s._repr_html_() + + expected_rows = 10 + assert html.count("") - 2 == expected_rows diff --git a/py-polars/tests/unit/datatypes/test_array.py b/py-polars/tests/unit/datatypes/test_array.py index 30f226115a0c..c532494676d0 100644 --- a/py-polars/tests/unit/datatypes/test_array.py +++ b/py-polars/tests/unit/datatypes/test_array.py @@ -1,11 +1,12 @@ import datetime +from datetime import timedelta from typing import Any import pytest import polars as pl from polars.exceptions import InvalidOperationError -from polars.testing import assert_series_equal +from polars.testing import assert_frame_equal, assert_series_equal def test_cast_list_array() -> None: @@ -207,6 +208,94 @@ def test_cast_list_to_array(data: Any, inner_type: pl.DataType) -> None: assert s.to_list() == data +@pytest.fixture() +def data_dispersion() -> pl.DataFrame: + return pl.DataFrame( + { + "int": [[1, 2, 3, 4, 5]], + "float": [[1.0, 2.0, 3.0, 4.0, 5.0]], + "duration": [[1000, 2000, 3000, 4000, 5000]], + }, + schema={ + "int": pl.Array(pl.Int64, 5), + "float": pl.Array(pl.Float64, 5), + "duration": pl.Array(pl.Duration, 5), + }, + ) + + +def test_arr_var(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").arr.var().name.suffix("_var"), + pl.col("float").arr.var().name.suffix("_var"), + pl.col("duration").arr.var().name.suffix("_var"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_var", [2.5], dtype=pl.Float64), + pl.Series("float_var", [2.5], dtype=pl.Float64), + pl.Series( + "duration_var", + [timedelta(microseconds=2000)], + dtype=pl.Duration(time_unit="ms"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_arr_std(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").arr.std().name.suffix("_std"), + pl.col("float").arr.std().name.suffix("_std"), + pl.col("duration").arr.std().name.suffix("_std"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_std", [1.5811388300841898], dtype=pl.Float64), + pl.Series("float_std", [1.5811388300841898], dtype=pl.Float64), + pl.Series( + "duration_std", + [timedelta(microseconds=1581)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_arr_median(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").arr.median().name.suffix("_median"), + pl.col("float").arr.median().name.suffix("_median"), + pl.col("duration").arr.median().name.suffix("_median"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_median", [3.0], dtype=pl.Float64), + pl.Series("float_median", [3.0], dtype=pl.Float64), + pl.Series( + "duration_median", + [timedelta(microseconds=3000)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + + assert_frame_equal(result, expected) + + def test_array_repeat() -> None: dtype = pl.Array(pl.UInt8, width=1) s = pl.repeat([42], n=3, dtype=dtype, eager=True) diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index 4e02decb8fe9..3ffabdc02d17 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -574,9 +574,9 @@ def test_nested_categorical_aggregation_7848() -> None: "letter": ["a", "b", "c", "d", "e", "f", "g"], } ).with_columns([pl.col("letter").cast(pl.Categorical)]).group_by( - maintain_order=True, by=["group"] + "group", maintain_order=True ).all().with_columns(pl.col("letter").list.len().alias("c_group")).group_by( - by=["c_group"], maintain_order=True + ["c_group"], maintain_order=True ).agg(pl.col("letter")).to_dict(as_series=False) == { "c_group": [2, 3], "letter": [[["a", "b"], ["f", "g"]], [["c", "d", "e"]]], diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py index df9f17b52b21..1c125de7a2eb 100644 --- a/py-polars/tests/unit/datatypes/test_decimal.py +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -2,16 +2,15 @@ import io import itertools +import operator from dataclasses import dataclass from decimal import Decimal as D -from typing import Any, NamedTuple +from typing import Any, Callable, NamedTuple -import numpy as np import pytest -from numpy.testing import assert_array_equal import polars as pl -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_equal @pytest.fixture(scope="module") @@ -140,6 +139,14 @@ def test_decimal_cast() -> None: assert result.to_dict(as_series=False) == expected +def test_decimal_cast_no_scale() -> None: + s = pl.Series().cast(pl.Decimal) + assert s.dtype == pl.Decimal(precision=None, scale=0) + + s = pl.Series([D("10.0")]).cast(pl.Decimal) + assert s.dtype == pl.Decimal(precision=None, scale=1) + + def test_decimal_scale_precision_roundtrip(monkeypatch: Any) -> None: monkeypatch.setenv("POLARS_ACTIVATE_DECIMAL", "1") assert pl.from_arrow(pl.Series("dec", [D("10.0")]).to_arrow()).item() == D("10.0") @@ -188,6 +195,34 @@ def test_read_csv_decimal(monkeypatch: Any) -> None: ] +def test_decimal_eq_number() -> None: + a = pl.Series([D("1.5"), D("22.25"), D("10.0")], dtype=pl.Decimal) + assert_series_equal(a == 1, pl.Series([False, False, False])) + assert_series_equal(a == 1.5, pl.Series([True, False, False])) + assert_series_equal(a == D("1.5"), pl.Series([True, False, False])) + assert_series_equal(a == pl.Series([D("1.5")]), pl.Series([True, False, False])) + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (operator.le, pl.Series([None, True, True, True, True, True])), + (operator.lt, pl.Series([None, False, False, False, True, True])), + (operator.ge, pl.Series([None, True, True, True, False, False])), + (operator.gt, pl.Series([None, False, False, False, False, False])), + ], +) +def test_decimal_compare( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + s = pl.Series( + [None, D("1.2"), D("2.13"), D("4.99"), D("2.13"), D("1.2")], dtype=pl.Decimal + ) + s2 = pl.Series([None, D("1.200"), D("2.13"), D("4.99"), D("4.99"), D("2.13")]) + + assert_series_equal(op(s, s2), expected) + + def test_decimal_arithmetic() -> None: df = pl.DataFrame( { @@ -195,25 +230,29 @@ def test_decimal_arithmetic() -> None: "b": [D("20.1"), D("10.19"), D("39.21")], } ) + dt = pl.Decimal(20, 10) out = df.select( out1=pl.col("a") * pl.col("b"), out2=pl.col("a") + pl.col("b"), out3=pl.col("a") / pl.col("b"), out4=pl.col("a") - pl.col("b"), + out5=pl.col("a").cast(dt) / pl.col("b").cast(dt), ) assert out.dtypes == [ + pl.Decimal(precision=None, scale=4), pl.Decimal(precision=None, scale=2), + pl.Decimal(precision=None, scale=6), pl.Decimal(precision=None, scale=2), - pl.Decimal(precision=None, scale=2), - pl.Decimal(precision=None, scale=2), + pl.Decimal(precision=None, scale=14), ] assert out.to_dict(as_series=False) == { - "out1": [D("2.01"), D("102.91"), D("3921.39")], + "out1": [D("2.0100"), D("102.9190"), D("3921.3921")], "out2": [D("20.20"), D("20.29"), D("139.22")], - "out3": [D("0.00"), D("0.99"), D("2.55")], + "out3": [D("0.004975"), D("0.991167"), D("2.550624")], "out4": [D("-20.00"), D("-0.09"), D("60.80")], + "out5": [D("0.00497512437810"), D("0.99116781157998"), D("2.55062484060188")], } @@ -222,26 +261,27 @@ def test_decimal_series_value_arithmetic() -> None: out1 = s + 10 out2 = s + D("10") - with pytest.raises(pl.InvalidOperationError): - s + D("10.0001") + out3 = s + D("10.0001") out4 = s * 2 / 3 out5 = s / D("1.5") out6 = s - 5 assert out1.dtype == pl.Decimal(precision=None, scale=2) assert out2.dtype == pl.Decimal(precision=None, scale=2) - assert out4.dtype == pl.Decimal(precision=None, scale=2) - assert out5.dtype == pl.Decimal(precision=None, scale=2) + assert out3.dtype == pl.Decimal(precision=None, scale=4) + assert out4.dtype == pl.Decimal(precision=None, scale=6) + assert out5.dtype == pl.Decimal(precision=None, scale=6) assert out6.dtype == pl.Decimal(precision=None, scale=2) assert out1.to_list() == [D("10.1"), D("20.1"), D("110.01")] assert out2.to_list() == [D("10.1"), D("20.1"), D("110.01")] + assert out3.to_list() == [D("10.1001"), D("20.1001"), D("110.0101")] assert out4.to_list() == [ - D("0.06"), - D("6.73"), - D("66.67"), + D("0.066666"), + D("6.733333"), + D("66.673333"), ] # TODO: do we want floor instead of round? - assert out5.to_list() == [D("0.06"), D("6.73"), D("66.67")] + assert out5.to_list() == [D("0.066666"), D("6.733333"), D("66.673333")] assert out6.to_list() == [D("-4.9"), D("5.1"), D("95.01")] @@ -299,18 +339,10 @@ def test_decimal_write_parquet_12375() -> None: df.write_parquet(f) -@pytest.mark.parametrize("use_pyarrow", [True, False]) -def test_decimal_numpy_export(use_pyarrow: bool) -> None: - decimal_data = [D("1.234"), D("2.345"), D("-3.456")] - - s = pl.Series("n", decimal_data) - df = s.to_frame() - - assert_array_equal( - np.array(decimal_data), - s.to_numpy(use_pyarrow=use_pyarrow), - ) - assert_array_equal( - np.array(decimal_data).reshape((-1, 1)), - df.to_numpy(use_pyarrow=use_pyarrow), - ) +def test_decimal_list_get_13847() -> None: + with pl.Config() as cfg: + cfg.activate_decimals() + df = pl.DataFrame({"a": [[D("1.1"), D("1.2")], [D("2.1")]]}) + out = df.select(pl.col("a").list.get(0)) + expected = pl.DataFrame({"a": [D("1.1"), D("2.1")]}) + assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/datatypes/test_duration.py b/py-polars/tests/unit/datatypes/test_duration.py index e9db9940c5b5..5f6638bec19c 100644 --- a/py-polars/tests/unit/datatypes/test_duration.py +++ b/py-polars/tests/unit/datatypes/test_duration.py @@ -1,5 +1,7 @@ from datetime import timedelta +import pytest + import polars as pl from polars.testing import assert_frame_equal @@ -20,7 +22,9 @@ def test_duration_cum_sum() -> None: def test_duration_std_var() -> None: - df = pl.DataFrame({"duration": [10, 5, 3]}, schema={"duration": pl.Duration}) + df = pl.DataFrame( + {"duration": [1000, 5000, 3000]}, schema={"duration": pl.Duration} + ) result = df.select( pl.col("duration").var().name.suffix("_var"), @@ -31,15 +35,27 @@ def test_duration_std_var() -> None: [ pl.Series( "duration_var", - [timedelta(microseconds=13)], - dtype=pl.Duration(time_unit="us"), + [timedelta(microseconds=4000)], + dtype=pl.Duration(time_unit="ms"), ), pl.Series( "duration_std", - [timedelta(microseconds=3)], + [timedelta(microseconds=2000)], dtype=pl.Duration(time_unit="us"), ), ] ) assert_frame_equal(result, expected) + + +def test_series_duration_std_var() -> None: + s = pl.Series([timedelta(days=1), timedelta(days=2), timedelta(days=4)]) + assert s.std() == timedelta(days=1, seconds=45578, microseconds=180014) + assert s.var() == timedelta(days=201600000) + + +def test_series_duration_var_overflow() -> None: + s = pl.Series([timedelta(days=10), timedelta(days=20), timedelta(days=40)]) + with pytest.raises(pl.PolarsPanicError, match="OverflowError"): + s.var() diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index 8afbee7d9b78..6a3f5b39f814 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -9,7 +9,7 @@ import polars as pl from polars import StringCache -from polars.testing import assert_series_equal +from polars.testing import assert_frame_equal, assert_series_equal def test_enum_creation() -> None: @@ -364,7 +364,7 @@ def test_enum_categories_unique() -> None: def test_enum_categories_series_input() -> None: - categories = pl.Series("a", ["x", "y", "z"]) + categories = pl.Series("a", ["a", "b", "c"]) dtype = pl.Enum(categories) assert_series_equal(dtype.categories, categories.alias("category")) @@ -402,3 +402,22 @@ def test_enum_cast_from_other_integer_dtype_oob() -> None: pl.ComputeError, match="conversion from `u64` to `u32` failed in column" ): series.cast(enum_dtype) + + +def test_enum_creating_col_expr() -> None: + df = pl.DataFrame( + { + "col1": ["a", "b", "c"], + "col2": ["d", "e", "f"], + "col3": ["g", "h", "i"], + }, + schema={ + "col1": pl.Enum(["a", "b", "c"]), + "col2": pl.Categorical(), + "col3": pl.Enum(["g", "h", "i"]), + }, + ) + + out = df.select(pl.col(pl.Enum)) + expected = df.select("col1", "col3") + assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/datatypes/test_float.py b/py-polars/tests/unit/datatypes/test_float.py index 62c975b26298..aef21b0d444d 100644 --- a/py-polars/tests/unit/datatypes/test_float.py +++ b/py-polars/tests/unit/datatypes/test_float.py @@ -1,4 +1,7 @@ +import pytest + import polars as pl +from polars.testing import assert_series_equal def test_nan_in_group_by_agg() -> None: @@ -32,3 +35,238 @@ def test_nan_aggregations() -> None: str(df.group_by("b").agg(aggs).to_dict(as_series=False)) == "{'b': [1], 'max': [3.0], 'min': [1.0], 'nan_max': [nan], 'nan_min': [nan]}" ) + + +@pytest.mark.parametrize( + ("s", "expect"), + [ + ( + pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ), + pl.Series("x", [None, 0.0, 1.0, float("nan")]), + ), + ( + # No nulls + pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + ], + ), + pl.Series("x", [0.0, 1.0, float("nan")]), + ), + ], +) +def test_unique(s: pl.Series, expect: pl.Series) -> None: + out = s.unique() + assert_series_equal(expect, out) + + out = s.n_unique() # type: ignore[assignment] + assert expect.len() == out + + out = s.gather(s.arg_unique()).sort() + assert_series_equal(expect, out) + + +def test_unique_counts() -> None: + s = pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + expect = pl.Series("x", [2, 2, 1, 1], dtype=pl.UInt32) + out = s.unique_counts() + assert_series_equal(expect, out) + + +def test_hash() -> None: + s = pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ).hash() + + # check them against each other since hash is not stable + assert s.item(0) == s.item(1) # hash(-0.0) == hash(0.0) + assert s.item(2) == s.item(3) # hash(float('-nan')) == hash(float('nan')) + + +def test_group_by() -> None: + # Test num_groups_proxy + # * -0.0 and 0.0 in same groups + # * -nan and nan in same groups + df = ( + pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + .to_frame() + .with_row_index() + .with_columns(a=pl.lit("a")) + ) + + expect = pl.Series("index", [[0, 1], [2, 3], [4], [5]], dtype=pl.List(pl.UInt32)) + expect_no_null = expect.head(3) + + for group_keys in (("x",), ("x", "a")): + for maintain_order in (True, False): + for drop_nulls in (True, False): + out = df + if drop_nulls: + out = out.drop_nulls() + + out = ( + out.group_by(group_keys, maintain_order=maintain_order) # type: ignore[assignment] + .agg("index") + .sort(pl.col("index").list.get(0)) + .select("index") + .to_series() + ) + + if drop_nulls: + assert_series_equal(expect_no_null, out) # type: ignore[arg-type] + else: + assert_series_equal(expect, out) # type: ignore[arg-type] + + +def test_joins() -> None: + # Test that -0.0 joins with 0.0 and nan joins with nan + df = ( + pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + .to_frame() + .with_row_index() + .with_columns(a=pl.lit("a")) + ) + + rhs = ( + pl.Series("x", [0.0, float("nan"), 3.0]) + .to_frame() + .with_columns(a=pl.lit("a"), rhs=True) + ) + + for join_on in ( + # Single and multiple keys + ("x",), + ( + "x", + "a", + ), + ): + how = "left" + expect = pl.Series("rhs", [True, True, True, True, None, None]) + out = df.join(rhs, on=join_on, how=how).sort("index").select("rhs").to_series() # type: ignore[arg-type] + assert_series_equal(expect, out) + + how = "inner" + expect = pl.Series("index", [0, 1, 2, 3], dtype=pl.UInt32) + out = ( + df.join(rhs, on=join_on, how=how).sort("index").select("index").to_series() # type: ignore[arg-type] + ) + assert_series_equal(expect, out) + + how = "outer" + expect = pl.Series("rhs", [True, True, True, True, None, None, True]) + out = ( + df.join(rhs, on=join_on, how=how) # type: ignore[arg-type] + .sort("index", nulls_last=True) + .select("rhs") + .to_series() + ) + assert_series_equal(expect, out) + + how = "semi" + expect = pl.Series("x", [-0.0, 0.0, float("-nan"), float("nan")]) + out = ( + df.join(rhs, on=join_on, how=how) # type: ignore[arg-type] + .sort("index", nulls_last=True) + .select("x") + .to_series() + ) + assert_series_equal(expect, out) + + how = "anti" + expect = pl.Series("x", [1.0, None]) + out = ( + df.join(rhs, on=join_on, how=how) # type: ignore[arg-type] + .sort("index", nulls_last=True) + .select("x") + .to_series() + ) + assert_series_equal(expect, out) + + # test asof + # note that nans never join because nans are always greater than the other + # side of the comparison (i.e. NaN > tolerance) + expect = pl.Series("rhs", [True, True, None, None, None, None]) + out = ( + df.sort("x") + .join_asof(rhs.sort("x"), on="x", tolerance=0) + .sort("index") + .select("rhs") + .to_series() + ) + assert_series_equal(expect, out) + + +def test_first_last_distinct() -> None: + s = pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + + assert_series_equal( + pl.Series("x", [True, False, True, False, True, True]), s.is_first_distinct() + ) + + assert_series_equal( + pl.Series("x", [False, True, False, True, True, True]), s.is_last_distinct() + ) diff --git a/py-polars/tests/unit/datatypes/test_integer.py b/py-polars/tests/unit/datatypes/test_integer.py index 1d3ef39dacb7..306154e6d936 100644 --- a/py-polars/tests/unit/datatypes/test_integer.py +++ b/py-polars/tests/unit/datatypes/test_integer.py @@ -13,3 +13,13 @@ def test_integer_float_functions() -> None: "nan": [False, False], "not_na": [True, True], } + + +def test_int_negate_operation() -> None: + assert pl.Series([1, 2, 3, 4, 50912341409]).not_().to_list() == [ + -2, + -3, + -4, + -5, + -50912341410, + ] diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 627971565065..f439781b4422 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -1,7 +1,7 @@ from __future__ import annotations import pickle -from datetime import date, datetime, time +from datetime import date, datetime, time, timedelta from decimal import Decimal from typing import TYPE_CHECKING, Any @@ -664,3 +664,91 @@ def test_as_list_logical_type() -> None: assert df.group_by(True).agg( pl.col("timestamp").gather(pl.col("value").arg_max()) ).to_dict(as_series=False) == {"literal": [True], "timestamp": [[date(2000, 1, 1)]]} + + +@pytest.fixture() +def data_dispersion() -> pl.DataFrame: + return pl.DataFrame( + { + "int": [[1, 2, 3, 4, 5]], + "float": [[1.0, 2.0, 3.0, 4.0, 5.0]], + "duration": [[1000, 2000, 3000, 4000, 5000]], + }, + schema={ + "int": pl.List(pl.Int64), + "float": pl.List(pl.Float64), + "duration": pl.List(pl.Duration), + }, + ) + + +def test_list_var(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").list.var().name.suffix("_var"), + pl.col("float").list.var().name.suffix("_var"), + pl.col("duration").list.var().name.suffix("_var"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_var", [2.5], dtype=pl.Float64), + pl.Series("float_var", [2.5], dtype=pl.Float64), + pl.Series( + "duration_var", + [timedelta(microseconds=2000)], + dtype=pl.Duration(time_unit="ms"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_list_std(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").list.std().name.suffix("_std"), + pl.col("float").list.std().name.suffix("_std"), + pl.col("duration").list.std().name.suffix("_std"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_std", [1.5811388300841898], dtype=pl.Float64), + pl.Series("float_std", [1.5811388300841898], dtype=pl.Float64), + pl.Series( + "duration_std", + [timedelta(microseconds=1581)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_list_median(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").list.median().name.suffix("_median"), + pl.col("float").list.median().name.suffix("_median"), + pl.col("duration").list.median().name.suffix("_median"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_median", [3.0], dtype=pl.Float64), + pl.Series("float_median", [3.0], dtype=pl.Float64), + pl.Series( + "duration_median", + [timedelta(microseconds=3000)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/datatypes/test_null.py b/py-polars/tests/unit/datatypes/test_null.py index 11b4355667e9..3d8db5ec5f2b 100644 --- a/py-polars/tests/unit/datatypes/test_null.py +++ b/py-polars/tests/unit/datatypes/test_null.py @@ -1,3 +1,9 @@ +from __future__ import annotations + +from typing import Any + +import pytest + import polars as pl from polars.testing import assert_frame_equal @@ -22,3 +28,52 @@ def test_null_grouping_12950() -> None: assert pl.DataFrame({"x": None}).slice(0, 0).unique().to_dict(as_series=False) == { "x": [] } + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (pl.Expr.gt, [None, None]), + (pl.Expr.lt, [None, None]), + (pl.Expr.ge, [None, None]), + (pl.Expr.le, [None, None]), + (pl.Expr.eq, [None, None]), + (pl.Expr.eq_missing, [True, True]), + (pl.Expr.ne, [None, None]), + (pl.Expr.ne_missing, [False, False]), + ], +) +def test_null_comp_14118(op: Any, expected: list[None | bool]) -> None: + df = pl.DataFrame( + { + "a": [None, None], + "b": [None, None], + } + ) + + output_df = df.select( + cmp=op(pl.col("a"), pl.col("b")), + broadcast_lhs=op(pl.lit(None), pl.col("b")), + broadcast_rhs=op(pl.col("a"), pl.lit(None)), + ) + + expected_df = pl.DataFrame( + { + "cmp": expected, + "broadcast_lhs": expected, + "broadcast_rhs": expected, + }, + schema={ + "cmp": pl.Boolean, + "broadcast_lhs": pl.Boolean, + "broadcast_rhs": pl.Boolean, + }, + ) + assert_frame_equal(output_df, expected_df) + + +def test_null_hash_rows_14100() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [None, None, None, None]}) + assert df.hash_rows().dtype == pl.UInt64 + assert df["b"].hash().dtype == pl.UInt64 + assert df.select([pl.col("b").hash().alias("foo")])["foo"].dtype == pl.UInt64 diff --git a/py-polars/tests/unit/datatypes/test_object.py b/py-polars/tests/unit/datatypes/test_object.py index 768e64d38514..0788f9e2c079 100644 --- a/py-polars/tests/unit/datatypes/test_object.py +++ b/py-polars/tests/unit/datatypes/test_object.py @@ -1,3 +1,4 @@ +from pathlib import Path from uuid import uuid4 import numpy as np @@ -153,3 +154,9 @@ def test_null_obj_str_13512() -> None: "│ 1 ┆ null │\n" "└─────┴────────┘" ) + + +def test_format_object_series_14267() -> None: + s = pl.Series([Path(), Path("abc")]) + expected = "shape: (2,)\n" "Series: '' [o][object]\n" "[\n" "\t.\n" "\tabc\n" "]" + assert str(s) == expected diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index c50a31117964..3cddc1d7aaee 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -3,7 +3,7 @@ import contextlib import io from datetime import date, datetime, time, timedelta, timezone -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np import pandas as pd @@ -12,7 +12,11 @@ import polars as pl from polars.datatypes import DATETIME_DTYPES, DTYPE_TEMPORAL_UNITS, TEMPORAL_DTYPES -from polars.exceptions import ComputeError, TimeZoneAwareConstructorWarning +from polars.exceptions import ( + ComputeError, + PolarsInefficientMapWarning, + TimeZoneAwareConstructorWarning, +) from polars.testing import ( assert_frame_equal, assert_series_equal, @@ -203,20 +207,6 @@ def test_from_pydatetime() -> None: assert s.dt[0] == dates[0] -@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) -def test_from_numpy_timedelta(time_unit: Literal["ns", "us", "ms"]) -> None: - s = pl.Series( - "name", - np.array( - [timedelta(days=1), timedelta(seconds=1)], dtype=f"timedelta64[{time_unit}]" - ), - ) - assert s.dtype == pl.Duration(time_unit) - assert s.name == "name" - assert s.dt[0] == timedelta(days=1) - assert s.dt[1] == timedelta(seconds=1) - - def test_int_to_python_datetime() -> None: df = pl.DataFrame({"a": [100_000_000, 200_000_000]}).with_columns( [ @@ -283,43 +273,10 @@ def test_int_to_python_timedelta() -> None: ] assert df.select( - [pl.col(col).dt.timestamp() for col in ("c", "d", "e")] + [pl.col(col).cast(pl.Int64) for col in ("c", "d", "e")] ).rows() == [(100001, 100001, 100001), (200002, 200002, 200002)] -def test_from_numpy() -> None: - # note: numpy timeunit support is limited to those supported by polars. - # as a result, datetime64[s] raises - x = np.asarray(range(100_000, 200_000, 10_000), dtype="datetime64[s]") - with pytest.raises(ValueError, match="Please cast to the closest supported unit"): - pl.Series(x) - - -@pytest.mark.parametrize( - ("numpy_time_unit", "expected_values", "expected_dtype"), - [ - ("ns", ["1970-01-02T01:12:34.123456789"], pl.Datetime("ns")), - ("us", ["1970-01-02T01:12:34.123456"], pl.Datetime("us")), - ("ms", ["1970-01-02T01:12:34.123"], pl.Datetime("ms")), - ("D", ["1970-01-02"], pl.Date), - ], -) -def test_from_numpy_supported_units( - numpy_time_unit: str, - expected_values: list[str], - expected_dtype: PolarsTemporalType, -) -> None: - values = np.array( - ["1970-01-02T01:12:34.123456789123456789"], - dtype=f"datetime64[{numpy_time_unit}]", - ) - result = pl.from_numpy(values) - expected = ( - pl.Series("column_0", expected_values).str.strptime(expected_dtype).to_frame() - ) - assert_frame_equal(result, expected) - - def test_datetime_consistency() -> None: dt = datetime(2022, 7, 5, 10, 30, 45, 123455) df = pl.DataFrame({"date": [dt]}) @@ -480,36 +437,6 @@ def test_rows() -> None: assert rows[0][1] == datetime(1970, 1, 1, 0, 2, 3, 543000) -def test_series_to_numpy() -> None: - s0 = pl.Series("date", [123543, 283478, 1243]).cast(pl.Date) - s1 = pl.Series( - "datetime", [datetime(2021, 1, 2, 3, 4, 5), datetime(2021, 2, 3, 4, 5, 6)] - ) - s2 = pl.datetime_range( - datetime(2021, 1, 1, 0), - datetime(2021, 1, 1, 1), - interval="1h", - time_unit="ms", - eager=True, - ) - assert str(s0.to_numpy()) == "['2308-04-02' '2746-02-20' '1973-05-28']" - assert ( - str(s1.to_numpy()[:2]) - == "['2021-01-02T03:04:05.000000' '2021-02-03T04:05:06.000000']" - ) - assert ( - str(s2.to_numpy()[:2]) - == "['2021-01-01T00:00:00.000' '2021-01-01T01:00:00.000']" - ) - s3 = pl.Series([timedelta(hours=1), timedelta(hours=-2)]) - out = np.array([3_600_000_000_000, -7_200_000_000_000], dtype="timedelta64[ns]") - assert (s3.to_numpy() == out).all() - - s4 = pl.Series([time(10, 30, 45), time(23, 59, 59)]) - out = np.array([time(10, 30, 45), time(23, 59, 59)], dtype="object") - assert (s4.to_numpy() == out).all() - - @pytest.mark.parametrize( ("one", "two"), [ @@ -1024,45 +951,50 @@ def test_temporal_dtypes_map_elements( ) const_dtm = datetime(2010, 9, 12) - assert_frame_equal( - df.with_columns( - [ - # don't actually do any of this; native expressions are MUCH faster ;) - pl.col("timestamp") - .map_elements(lambda x: const_dtm, skip_nulls=skip_nulls) - .alias("const_dtm"), - pl.col("timestamp") - .map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls) - .alias("date"), - pl.col("timestamp") - .map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls) - .alias("time"), - ] - ), - pl.DataFrame( - [ - ( - datetime(2010, 9, 12, 10, 19, 54), - datetime(2010, 9, 12, 0, 0), - date(2010, 9, 12), - time(10, 19, 54), - ), - (None, expected_value, None, None), - ( - datetime(2009, 2, 13, 23, 31, 30), - datetime(2010, 9, 12, 0, 0), - date(2009, 2, 13), - time(23, 31, 30), - ), - ], - schema={ - "timestamp": pl.Datetime("ms"), - "const_dtm": pl.Datetime("us"), - "date": pl.Date, - "time": pl.Time, - }, - ), - ) + with pytest.warns( + PolarsInefficientMapWarning, + match=r"(?s)Replace this expression.*lambda x:", + ): + assert_frame_equal( + df.with_columns( + [ + # don't actually do this; native expressions are MUCH faster ;) + pl.col("timestamp") + .map_elements(lambda x: const_dtm, skip_nulls=skip_nulls) + .alias("const_dtm"), + # note: the below now trigger a PolarsInefficientMapWarning + pl.col("timestamp") + .map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls) + .alias("date"), + pl.col("timestamp") + .map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls) + .alias("time"), + ] + ), + pl.DataFrame( + [ + ( + datetime(2010, 9, 12, 10, 19, 54), + datetime(2010, 9, 12, 0, 0), + date(2010, 9, 12), + time(10, 19, 54), + ), + (None, expected_value, None, None), + ( + datetime(2009, 2, 13, 23, 31, 30), + datetime(2010, 9, 12, 0, 0), + date(2009, 2, 13), + time(23, 31, 30), + ), + ], + schema={ + "timestamp": pl.Datetime("ms"), + "const_dtm": pl.Datetime("us"), + "date": pl.Date, + "time": pl.Time, + }, + ), + ) def test_timelike_init() -> None: @@ -1343,22 +1275,6 @@ def test_rolling_by_() -> None: } -def test_date_to_time_cast_5111() -> None: - # check date -> time casts (fast-path: always 00:00:00) - df = pl.DataFrame( - { - "xyz": [ - date(1969, 1, 1), - date(1990, 3, 8), - date(2000, 6, 16), - date(2010, 9, 24), - date(2022, 12, 31), - ] - } - ).with_columns(pl.col("xyz").cast(pl.Time)) - assert df["xyz"].to_list() == [time(0), time(0), time(0), time(0), time(0)] - - def test_sum_duration() -> None: assert pl.DataFrame( [ @@ -1579,11 +1495,13 @@ def test_convert_time_zone_lazy_schema() -> None: def test_convert_time_zone_on_tz_naive() -> None: ts = pl.Series(["2020-01-01"]).str.strptime(pl.Datetime) - with pytest.raises( - ComputeError, - match="cannot call `convert_time_zone` on tz-naive; set a time zone first with `replace_time_zone`", - ): - ts.dt.convert_time_zone("Africa/Bamako") + result = ts.dt.convert_time_zone("Asia/Kathmandu").item() + expected = datetime(2020, 1, 1, 5, 45, tzinfo=ZoneInfo(key="Asia/Kathmandu")) + assert result == expected + result = ( + ts.dt.replace_time_zone("UTC").dt.convert_time_zone("Asia/Kathmandu").item() + ) + assert result == expected def test_tz_aware_get_idx_5010() -> None: @@ -2743,3 +2661,9 @@ def test_rolling_duplicates() -> None: assert df.sort("ts").with_columns(pl.col("value").rolling_max("1d", by="ts"))[ "value" ].to_list() == [1, 1] + + +def test_datetime_time_unit_none_deprecated() -> None: + with pytest.deprecated_call(): + dtype = pl.Datetime(time_unit=None) # type: ignore[arg-type] + assert dtype.time_unit == "us" diff --git a/py-polars/tests/unit/expr/test_exprs.py b/py-polars/tests/unit/expr/test_exprs.py index e9b501d2a972..6e0be82f1557 100644 --- a/py-polars/tests/unit/expr/test_exprs.py +++ b/py-polars/tests/unit/expr/test_exprs.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from datetime import date, datetime, time, timedelta, timezone +from datetime import date, datetime, timedelta, timezone from itertools import permutations from typing import Any, cast @@ -129,15 +129,15 @@ def test_unique_stable() -> None: def test_entropy() -> None: df = pl.DataFrame( { - "group": ["A", "A", "A", "B", "B", "B", "B"], - "id": [1, 2, 1, 4, 5, 4, 6], + "group": ["A", "A", "A", "B", "B", "B", "B", "C"], + "id": [1, 2, 1, 4, 5, 4, 6, 7], } ) result = df.group_by("group", maintain_order=True).agg( pl.col("id").entropy(normalize=True) ) expected = pl.DataFrame( - {"group": ["A", "B"], "id": [1.0397207708399179, 1.371381017771811]} + {"group": ["A", "B", "C"], "id": [1.0397207708399179, 1.371381017771811, 0.0]} ) assert_frame_equal(result, expected) @@ -430,34 +430,6 @@ def test_logical_boolean() -> None: df.select([(pl.col("a") > pl.col("b")) or (pl.col("b") > pl.col("b"))]) -# https://github.com/pola-rs/polars/issues/4951 -def test_ewm_with_multiple_chunks() -> None: - df0 = pl.DataFrame( - data=[ - ("w", 6.0, 1.0), - ("x", 5.0, 2.0), - ("y", 4.0, 3.0), - ("z", 3.0, 4.0), - ], - schema=["a", "b", "c"], - ).with_columns( - [ - pl.col(pl.Float64).log().diff().name.prefix("ld_"), - ] - ) - assert df0.n_chunks() == 1 - - # NOTE: We aren't testing whether `select` creates two chunks; - # we just need two chunks to properly test `ewm_mean` - df1 = df0.select(["ld_b", "ld_c"]) - assert df1.n_chunks() == 2 - - ewm_std = df1.with_columns( - pl.all().ewm_std(com=20).name.prefix("ewm_"), - ) - assert ewm_std.null_count().sum_horizontal()[0] == 4 - - def test_lit_dtypes() -> None: def lit_series(value: Any, dtype: pl.PolarsDataType | None) -> pl.Series: return pl.select(pl.lit(value, dtype=dtype)).to_series() @@ -477,7 +449,7 @@ def lit_series(value: Any, dtype: pl.PolarsDataType | None) -> pl.Series: "dtm_aware_0": lit_series(d, pl.Datetime("us", "Asia/Kathmandu")), "dtm_aware_1": lit_series(d_tz, pl.Datetime("us")), "dtm_aware_2": lit_series(d_tz, None), - "dtm_aware_3": lit_series(d, pl.Datetime(None, "Asia/Kathmandu")), + "dtm_aware_3": lit_series(d, pl.Datetime(time_zone="Asia/Kathmandu")), "dur_ms": lit_series(td, pl.Duration("ms")), "dur_us": lit_series(td, pl.Duration("us")), "dur_ns": lit_series(td, pl.Duration("ns")), @@ -688,58 +660,6 @@ def test_tail() -> None: } -@pytest.mark.parametrize( - ("const", "dtype"), - [ - (1, pl.Int8), - (4, pl.UInt32), - (4.5, pl.Float32), - (None, pl.Float64), - ("白鵬翔", pl.String), - (date.today(), pl.Date), - (datetime.now(), pl.Datetime("ns")), - (time(23, 59, 59), pl.Time), - (timedelta(hours=7, seconds=123), pl.Duration("ms")), - ], -) -def test_extend_constant(const: Any, dtype: pl.PolarsDataType) -> None: - df = pl.DataFrame({"a": pl.Series("s", [None], dtype=dtype)}) - - expected = pl.DataFrame( - {"a": pl.Series("s", [None, const, const, const], dtype=dtype)} - ) - - assert_frame_equal(df.select(pl.col("a").extend_constant(const, 3)), expected) - - -@pytest.mark.parametrize( - ("const", "dtype"), - [ - (1, pl.Int8), - (4, pl.UInt32), - (4.5, pl.Float32), - (None, pl.Float64), - ("白鵬翔", pl.String), - (date.today(), pl.Date), - (datetime.now(), pl.Datetime("ns")), - (time(23, 59, 59), pl.Time), - (timedelta(hours=7, seconds=123), pl.Duration("ms")), - ], -) -def test_extend_constant_arr(const: Any, dtype: pl.PolarsDataType) -> None: - """ - Test extend_constant in pl.List array. - - NOTE: This function currently fails when the Series is a list with a single [None] - value. Hence, this function does not begin with [[None]], but [[const]]. - """ - s = pl.Series("s", [[const]], dtype=pl.List(dtype)) - - expected = pl.Series("s", [[const, const, const, const]], dtype=pl.List(dtype)) - - assert_series_equal(s.list.eval(pl.element().extend_constant(const, 3)), expected) - - def test_is_not_deprecated() -> None: df = pl.DataFrame({"a": [True, False, True]}) diff --git a/py-polars/tests/unit/functions/aggregation/test_horizontal.py b/py-polars/tests/unit/functions/aggregation/test_horizontal.py index d5af947cb73f..4739c1698c53 100644 --- a/py-polars/tests/unit/functions/aggregation/test_horizontal.py +++ b/py-polars/tests/unit/functions/aggregation/test_horizontal.py @@ -45,6 +45,21 @@ def test_all_any_horizontally() -> None: assert "horizontal" not in dfltr.explain().lower() +def test_all_any_single_input() -> None: + df = pl.DataFrame({"a": [0, 1, None]}) + out = df.select( + all=pl.all_horizontal(pl.col("a")), any=pl.any_horizontal(pl.col("a")) + ) + + expected = pl.DataFrame( + { + "all": [False, True, None], + "any": [False, True, None], + } + ) + assert_frame_equal(out, expected) + + def test_all_any_accept_expr() -> None: lf = pl.LazyFrame( { @@ -353,6 +368,10 @@ def test_horizontal_broadcasting() -> None: df.select(sum=pl.sum_horizontal(1, "a", "b")).to_series(), pl.Series("sum", [5, 10]), ) + assert_series_equal( + df.select(mean=pl.mean_horizontal(1, "a", "b")).to_series(), + pl.Series("mean", [1.66666, 3.33333]), + ) assert_series_equal( df.select(max=pl.max_horizontal(4, "*")).to_series(), pl.Series("max", [4, 6]) ) @@ -368,3 +387,37 @@ def test_horizontal_broadcasting() -> None: df.select(all=pl.all_horizontal(True, pl.Series([True, False]))).to_series(), pl.Series("all", [True, False]), ) + + +def test_mean_horizontal() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [2.0, 4.0, 6.0], "c": [3, None, 9]}) + + result = lf.select(pl.mean_horizontal(pl.all())) + + expected = pl.LazyFrame({"a": [2.0, 3.0, 6.0]}, schema={"a": pl.Float64}) + assert_frame_equal(result, expected) + + +def test_mean_horizontal_no_columns() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [2.0, 4.0, 6.0], "c": [3, None, 9]}) + + with pytest.raises(pl.ComputeError, match="number of output rows is unknown"): + lf.select(pl.mean_horizontal()) + + +def test_mean_horizontal_no_rows() -> None: + lf = pl.LazyFrame({"a": [], "b": [], "c": []}).with_columns(pl.all().cast(pl.Int64)) + + result = lf.select(pl.mean_horizontal(pl.all())) + + expected = pl.LazyFrame({"a": []}, schema={"a": pl.Float64}) + assert_frame_equal(result, expected) + + +def test_mean_horizontal_all_null() -> None: + lf = pl.LazyFrame({"a": [1, None], "b": [2, None], "c": [None, None]}) + + result = lf.select(pl.mean_horizontal(pl.all())) + + expected = pl.LazyFrame({"a": [1.5, None]}, schema={"a": pl.Float64}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/functions/test_functions.py b/py-polars/tests/unit/functions/test_functions.py index e808d1c6bc1f..669d56313ab2 100644 --- a/py-polars/tests/unit/functions/test_functions.py +++ b/py-polars/tests/unit/functions/test_functions.py @@ -89,6 +89,19 @@ def test_concat_diagonal( assert_frame_equal(out, expected) +def test_concat_diagonal_relaxed_with_empty_frame() -> None: + df1 = pl.DataFrame() + df2 = pl.DataFrame( + { + "a": ["a", "b"], + "b": [1, 2], + } + ) + out = pl.concat((df1, df2), how="diagonal_relaxed") + expected = df2 + assert_frame_equal(out, expected) + + @pytest.mark.parametrize("lazy", [False, True]) def test_concat_horizontal(lazy: bool) -> None: a = pl.DataFrame({"a": ["a", "b"], "b": [1, 2]}) diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py index d5a25122c22a..10c0602b47c3 100644 --- a/py-polars/tests/unit/functions/test_when_then.py +++ b/py-polars/tests/unit/functions/test_when_then.py @@ -287,53 +287,39 @@ def test_predicate_broadcast() -> None: pl.col("x"), ], ) -@pytest.mark.parametrize( - "df", - [ - pl.Series("x", 5 * [1], dtype=pl.Int32) - .to_frame() - .with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean)) - ], -) def test_single_element_broadcast( mask_expr: pl.Expr, truthy_expr: pl.Expr, falsy_expr: pl.Expr, - df: pl.DataFrame, ) -> None: + df = ( + pl.Series("x", 5 * [1], dtype=pl.Int32) + .to_frame() + .with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean)) + ) + # Given that the lengths of the mask, truthy and falsy are all either: # - Length 1 # - Equal length to the maximum length of the 3. # This test checks that all length-1 exprs are broadcasted to the max length. - - expect = df.select("x").head( + result = df.select( + pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr) + ) + expected = df.select("x").head( df.select( pl.max_horizontal(mask_expr.len(), truthy_expr.len(), falsy_expr.len()) ).item() ) + assert_frame_equal(result, expected) - actual = df.select( - pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr) - ) - - assert_frame_equal( - expect, - actual, - ) - - actual = ( + result = ( df.group_by(pl.lit(True).alias("key")) .agg(pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr)) .drop("key") ) - - if expect.height > 1: - actual = actual.explode(pl.all()) - - assert_frame_equal( - expect, - actual, - ) + if expected.height > 1: + result = result.explode(pl.all()) + assert_frame_equal(result, expected) @pytest.mark.parametrize( diff --git a/py-polars/tests/unit/interchange/test_from_dataframe.py b/py-polars/tests/unit/interchange/test_from_dataframe.py index 8a9899d82ff6..62f34666f8b1 100644 --- a/py-polars/tests/unit/interchange/test_from_dataframe.py +++ b/py-polars/tests/unit/interchange/test_from_dataframe.py @@ -311,7 +311,7 @@ def test_column_to_series_use_sentinel_invalid_value() -> None: dtype = pl.Datetime("ns") mask_value = "invalid" - s = pl.Series([datetime(1970, 1, 1), mask_value, datetime(2000, 1, 1)], dtype=dtype) + s = pl.Series([datetime(1970, 1, 1), None, datetime(2000, 1, 1)], dtype=dtype) col = PatchableColumn(s) col.describe_null = (ColumnNullType.USE_SENTINEL, mask_value) diff --git a/py-polars/tests/unit/interop/numpy/__init__.py b/py-polars/tests/unit/interop/numpy/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/interop/numpy/test_from_numpy_df.py b/py-polars/tests/unit/interop/numpy/test_from_numpy_df.py new file mode 100644 index 000000000000..5577525c4a83 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_from_numpy_df.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars.type_aliases import PolarsTemporalType + + +def test_from_numpy() -> None: + data = np.array([[1, 2, 3], [4, 5, 6]]) + df = pl.from_numpy( + data, + schema=["a", "b"], + orient="col", + schema_overrides={"a": pl.UInt32, "b": pl.UInt32}, + ) + assert df.shape == (3, 2) + assert df.rows() == [(1, 4), (2, 5), (3, 6)] + assert df.schema == {"a": pl.UInt32, "b": pl.UInt32} + data2 = np.array(["foo", "bar"], dtype=object) + df2 = pl.from_numpy(data2) + assert df2.shape == (2, 1) + assert df2.rows() == [("foo",), ("bar",)] + assert df2.schema == {"column_0": pl.String} + with pytest.raises( + ValueError, + match="cannot create DataFrame from array with more than two dimensions", + ): + _ = pl.from_numpy(np.array([[[1]]])) + with pytest.raises( + ValueError, match="cannot create DataFrame from zero-dimensional array" + ): + _ = pl.from_numpy(np.array(1)) + + +def test_from_numpy_array_value() -> None: + df = pl.DataFrame({"A": [[2, 3]]}) + assert df.rows() == [([2, 3],)] + assert df.schema == {"A": pl.List(pl.Int64)} + + +def test_construct_from_ndarray_value() -> None: + array_cell = np.array([2, 3]) + df = pl.DataFrame(np.array([[array_cell, 4]], dtype=object)) + assert df.dtypes == [pl.Object, pl.Object] + to_numpy = df.to_numpy() + assert to_numpy.shape == (1, 2) + assert_array_equal(to_numpy[0][0], array_cell) + assert to_numpy[0][1] == 4 + + +def test_from_numpy_nparray_value() -> None: + array_cell = np.array([2, 3]) + df = pl.from_numpy(np.array([[array_cell, 4]], dtype=object)) + assert df.dtypes == [pl.Object, pl.Object] + to_numpy = df.to_numpy() + assert to_numpy.shape == (1, 2) + assert_array_equal(to_numpy[0][0], array_cell) + assert to_numpy[0][1] == 4 + + +def test_from_numpy_structured() -> None: + test_data = [ + ("Google Pixel 7", 521.90, True), + ("Apple iPhone 14 Pro", 999.00, True), + ("Samsung Galaxy S23 Ultra", 1199.99, False), + ("OnePlus 11", 699.00, True), + ] + # create a numpy structured array... + arr_structured = np.array( + test_data, + dtype=np.dtype( + [ + ("product", "U32"), + ("price_usd", "float64"), + ("in_stock", "bool"), + ] + ), + ) + # ...and also establish as a record array view + arr_records = arr_structured.view(np.recarray) + + # confirm that we can cleanly initialise a DataFrame from both, + # respecting the native dtypes and any schema overrides, etc. + for arr in (arr_structured, arr_records): + df = pl.DataFrame(data=arr).sort(by="price_usd", descending=True) + + assert df.schema == { + "product": pl.String, + "price_usd": pl.Float64, + "in_stock": pl.Boolean, + } + assert df.rows() == sorted(test_data, key=lambda row: -row[1]) + + for df in ( + pl.DataFrame( + data=arr, schema=["phone", ("price_usd", pl.Float32), "available"] + ), + pl.DataFrame( + data=arr, + schema=["phone", "price_usd", "available"], + schema_overrides={"price_usd": pl.Float32}, + ), + ): + assert df.schema == { + "phone": pl.String, + "price_usd": pl.Float32, + "available": pl.Boolean, + } + + +def test_from_numpy2() -> None: + # note: numpy timeunit support is limited to those supported by polars. + # as a result, datetime64[s] raises + x = np.asarray(range(100_000, 200_000, 10_000), dtype="datetime64[s]") + with pytest.raises(ValueError, match="Please cast to the closest supported unit"): + pl.Series(x) + + +@pytest.mark.parametrize( + ("numpy_time_unit", "expected_values", "expected_dtype"), + [ + ("ns", ["1970-01-02T01:12:34.123456789"], pl.Datetime("ns")), + ("us", ["1970-01-02T01:12:34.123456"], pl.Datetime("us")), + ("ms", ["1970-01-02T01:12:34.123"], pl.Datetime("ms")), + ("D", ["1970-01-02"], pl.Date), + ], +) +def test_from_numpy_supported_units( + numpy_time_unit: str, + expected_values: list[str], + expected_dtype: PolarsTemporalType, +) -> None: + values = np.array( + ["1970-01-02T01:12:34.123456789123456789"], + dtype=f"datetime64[{numpy_time_unit}]", + ) + result = pl.from_numpy(values) + expected = ( + pl.Series("column_0", expected_values).str.strptime(expected_dtype).to_frame() + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/interop/numpy/test_from_numpy_series.py b/py-polars/tests/unit/interop/numpy/test_from_numpy_series.py new file mode 100644 index 000000000000..67a9088c36b7 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_from_numpy_series.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +import polars as pl + +if TYPE_CHECKING: + from polars.type_aliases import TimeUnit + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_from_numpy_timedelta(time_unit: TimeUnit) -> None: + s = pl.Series( + "name", + np.array( + [timedelta(days=1), timedelta(seconds=1)], dtype=f"timedelta64[{time_unit}]" + ), + ) + assert s.dtype == pl.Duration(time_unit) + assert s.name == "name" + assert s.dt[0] == timedelta(days=1) + assert s.dt[1] == timedelta(seconds=1) diff --git a/py-polars/tests/unit/interop/numpy/test_numpy.py b/py-polars/tests/unit/interop/numpy/test_numpy.py new file mode 100644 index 000000000000..8fe721537b38 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_numpy.py @@ -0,0 +1,78 @@ +from typing import Any + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import polars as pl + + +@pytest.fixture( + params=[ + ("int8", [1, 3, 2], pl.Int8, np.int8), + ("int16", [1, 3, 2], pl.Int16, np.int16), + ("int32", [1, 3, 2], pl.Int32, np.int32), + ("int64", [1, 3, 2], pl.Int64, np.int64), + ("uint8", [1, 3, 2], pl.UInt8, np.uint8), + ("uint16", [1, 3, 2], pl.UInt16, np.uint16), + ("uint32", [1, 3, 2], pl.UInt32, np.uint32), + ("uint64", [1, 3, 2], pl.UInt64, np.uint64), + ("float32", [21.7, 21.8, 21], pl.Float32, np.float32), + ("float64", [21.7, 21.8, 21], pl.Float64, np.float64), + ("bool", [True, False, False], pl.Boolean, np.bool_), + ("object", [21.7, "string1", object()], pl.Object, np.object_), + ("str", ["string1", "string2", "string3"], pl.String, np.str_), + ("intc", [1, 3, 2], pl.Int32, np.intc), + ("uintc", [1, 3, 2], pl.UInt32, np.uintc), + ("str_fixed", ["string1", "string2", "string3"], pl.String, np.str_), + ( + "bytes", + [b"byte_string1", b"byte_string2", b"byte_string3"], + pl.Binary, + np.bytes_, + ), + ] +) +def numpy_interop_test_data(request: Any) -> Any: + return request.param + + +def test_df_from_numpy(numpy_interop_test_data: Any) -> None: + name, values, pl_dtype, np_dtype = numpy_interop_test_data + df = pl.DataFrame({name: np.array(values, dtype=np_dtype)}) + assert [pl_dtype] == df.dtypes + + +def test_asarray(numpy_interop_test_data: Any) -> None: + name, values, pl_dtype, np_dtype = numpy_interop_test_data + pl_series_to_numpy_array = np.asarray(pl.Series(name, values, pl_dtype)) + numpy_array = np.asarray(values, dtype=np_dtype) + assert_array_equal(pl_series_to_numpy_array, numpy_array) + + +@pytest.mark.parametrize("use_pyarrow", [True, False]) +def test_to_numpy(numpy_interop_test_data: Any, use_pyarrow: bool) -> None: + name, values, pl_dtype, np_dtype = numpy_interop_test_data + pl_series_to_numpy_array = pl.Series(name, values, pl_dtype).to_numpy( + use_pyarrow=use_pyarrow + ) + numpy_array = np.asarray(values, dtype=np_dtype) + assert_array_equal(pl_series_to_numpy_array, numpy_array) + + +def test_numpy_to_lit() -> None: + out = pl.select(pl.lit(np.array([1, 2, 3]))).to_series().to_list() + assert out == [1, 2, 3] + out = pl.select(pl.lit(np.float32(0))).to_series().to_list() + assert out == [0.0] + + +def test_numpy_disambiguation() -> None: + a = np.array([1, 2]) + df = pl.DataFrame({"a": a}) + result = df.with_columns(b=a).to_dict(as_series=False) # type: ignore[arg-type] + expected = { + "a": [1, 2], + "b": [1, 2], + } + assert result == expected diff --git a/py-polars/tests/unit/interop/numpy/test_to_numpy_df.py b/py-polars/tests/unit/interop/numpy/test_to_numpy_df.py new file mode 100644 index 000000000000..68b1f7696dd7 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_to_numpy_df.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from decimal import Decimal as D +from typing import TYPE_CHECKING + +import numpy as np +import pytest +from numpy.testing import assert_array_equal, assert_equal + +import polars as pl + +if TYPE_CHECKING: + from polars.type_aliases import IndexOrder + + +@pytest.mark.parametrize( + ("order", "f_contiguous", "c_contiguous"), + [("fortran", True, False), ("c", False, True)], +) +def test_to_numpy(order: IndexOrder, f_contiguous: bool, c_contiguous: bool) -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + + out_array = df.to_numpy(order=order) + expected_array = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=np.float64) + assert_array_equal(out_array, expected_array) + assert out_array.flags["F_CONTIGUOUS"] == f_contiguous + assert out_array.flags["C_CONTIGUOUS"] == c_contiguous + + structured_array = df.to_numpy(structured=True, order=order) + expected_array = np.array( + [(1, 1.0), (2, 2.0), (3, 3.0)], dtype=[("a", " None: + # round-trip structured array: validate init/export + structured_array = np.array( + [ + ("Google Pixel 7", 521.90, True), + ("Apple iPhone 14 Pro", 999.00, True), + ("OnePlus 11", 699.00, True), + ("Samsung Galaxy S23 Ultra", 1199.99, False), + ], + dtype=np.dtype( + [ + ("product", "U24"), + ("price_usd", "float64"), + ("in_stock", "bool"), + ] + ), + ) + df = pl.from_numpy(structured_array) + assert df.schema == { + "product": pl.String, + "price_usd": pl.Float64, + "in_stock": pl.Boolean, + } + exported_array = df.to_numpy(structured=True) + assert exported_array["product"].dtype == np.dtype("U24") + assert_array_equal(exported_array, structured_array) + + # none/nan values + df = pl.DataFrame({"x": ["a", None, "b"], "y": [5.5, None, -5.5]}) + exported_array = df.to_numpy(structured=True) + + assert exported_array.dtype == np.dtype([("x", object), ("y", float)]) + for name in df.columns: + assert_equal( + list(exported_array[name]), + ( + df[name].fill_null(float("nan")) + if df.schema[name].is_float() + else df[name] + ).to_list(), + ) + + +def test__array__() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + + out_array = np.asarray(df.to_numpy()) + expected_array = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=np.float64) + assert_array_equal(out_array, expected_array) + assert out_array.flags["F_CONTIGUOUS"] is True + + out_array = np.asarray(df.to_numpy(), np.uint8) + expected_array = np.array([[1, 1], [2, 2], [3, 3]], dtype=np.uint8) + assert_array_equal(out_array, expected_array) + assert out_array.flags["F_CONTIGUOUS"] is True + + +def test_numpy_preserve_uint64_4112() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}).with_columns(pl.col("a").hash()) + assert df.to_numpy().dtype == np.dtype("uint64") + assert df.to_numpy(structured=True).dtype == np.dtype([("a", "uint64")]) + + +@pytest.mark.parametrize("use_pyarrow", [True, False]) +def test_df_to_numpy_decimal(use_pyarrow: bool) -> None: + decimal_data = [D("1.234"), D("2.345"), D("-3.456")] + df = pl.Series("n", decimal_data).to_frame() + + result = df.to_numpy(use_pyarrow=use_pyarrow) + + expected = np.array(decimal_data).reshape((-1, 1)) + assert_array_equal(result, expected) + + +def test_df_to_numpy_zero_copy_path() -> None: + rows = 10 + cols = 5 + x = np.ones((rows, cols), order="F") + x[:, 1] = 2.0 + df = pl.DataFrame(x) + x = df.to_numpy(allow_copy=False) + assert x.flags["F_CONTIGUOUS"] + assert not x.flags["WRITEABLE"] + assert str(x[0, :]) == "[1. 2. 1. 1. 1.]" + + +def test_to_numpy_zero_copy_path_writable() -> None: + rows = 10 + cols = 5 + x = np.ones((rows, cols), order="F") + x[:, 1] = 2.0 + df = pl.DataFrame(x) + x = df.to_numpy(writable=True) + assert x.flags["WRITEABLE"] + + +def test_df_to_numpy_structured_not_zero_copy() -> None: + df = pl.DataFrame({"a": [1, 2]}) + msg = "cannot create structured array without copying data" + with pytest.raises(RuntimeError, match=msg): + df.to_numpy(structured=True, allow_copy=False) + + +def test_df_to_numpy_writable_not_zero_copy() -> None: + df = pl.DataFrame({"a": [1, 2]}) + msg = "cannot create writable array without copying data" + with pytest.raises(RuntimeError, match=msg): + df.to_numpy(allow_copy=False, writable=True) + + +def test_df_to_numpy_not_zero_copy() -> None: + df = pl.DataFrame({"a": [1, 2, None]}) + with pytest.raises(RuntimeError): + df.to_numpy(allow_copy=False) diff --git a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py new file mode 100644 index 000000000000..fe2909672fa8 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py @@ -0,0 +1,410 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from decimal import Decimal as D +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest +from hypothesis import given, settings +from numpy.testing import assert_array_equal + +import polars as pl +from polars.testing.parametric import series + +if TYPE_CHECKING: + import numpy.typing as npt + + +def assert_zero_copy(s: pl.Series, arr: np.ndarray[Any, Any]) -> None: + if s.len() == 0: + return + s_ptr = s._get_buffers()["values"]._get_buffer_info()[0] + arr_ptr = arr.__array_interface__["data"][0] + assert s_ptr == arr_ptr + + +def assert_allow_copy_false_raises(s: pl.Series) -> None: + with pytest.raises(ValueError, match="cannot return a zero-copy array"): + s.to_numpy(use_pyarrow=False, allow_copy=False) + + +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (pl.Int8, np.int8), + (pl.Int16, np.int16), + (pl.Int32, np.int32), + (pl.Int64, np.int64), + (pl.UInt8, np.uint8), + (pl.UInt16, np.uint16), + (pl.UInt32, np.uint32), + (pl.UInt64, np.uint64), + (pl.Float32, np.float32), + (pl.Float64, np.float64), + ], +) +def test_series_to_numpy_numeric_zero_copy( + dtype: pl.PolarsDataType, expected_dtype: npt.DTypeLike +) -> None: + s = pl.Series([1, 2, 3]).cast(dtype) + result = s.to_numpy(use_pyarrow=False, allow_copy=False) + + assert_zero_copy(s, result) + assert result.tolist() == s.to_list() + assert result.dtype == expected_dtype + + +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (pl.Int8, np.float32), + (pl.Int16, np.float32), + (pl.Int32, np.float64), + (pl.Int64, np.float64), + (pl.UInt8, np.float32), + (pl.UInt16, np.float32), + (pl.UInt32, np.float64), + (pl.UInt64, np.float64), + (pl.Float32, np.float32), + (pl.Float64, np.float64), + ], +) +def test_series_to_numpy_numeric_with_nulls( + dtype: pl.PolarsDataType, expected_dtype: npt.DTypeLike +) -> None: + s = pl.Series([1, 2, None], dtype=dtype, strict=False) + result = s.to_numpy(use_pyarrow=False) + + assert result.tolist()[:-1] == s.to_list()[:-1] + assert np.isnan(result[-1]) + assert result.dtype == expected_dtype + assert_allow_copy_false_raises(s) + + +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (pl.Duration, np.dtype("timedelta64[us]")), + (pl.Duration("ms"), np.dtype("timedelta64[ms]")), + (pl.Duration("us"), np.dtype("timedelta64[us]")), + (pl.Duration("ns"), np.dtype("timedelta64[ns]")), + (pl.Datetime, np.dtype("datetime64[us]")), + (pl.Datetime("ms"), np.dtype("datetime64[ms]")), + (pl.Datetime("us"), np.dtype("datetime64[us]")), + (pl.Datetime("ns"), np.dtype("datetime64[ns]")), + ], +) +def test_series_to_numpy_temporal_zero_copy( + dtype: pl.PolarsDataType, expected_dtype: npt.DTypeLike +) -> None: + values = [0, 2_000, 1_000_000] + s = pl.Series(values, dtype=dtype, strict=False) + result = s.to_numpy(use_pyarrow=False, allow_copy=False) + + assert_zero_copy(s, result) + # NumPy tolist returns integers for ns precision + if s.dtype.time_unit == "ns": # type: ignore[attr-defined] + assert result.tolist() == values + else: + assert result.tolist() == s.to_list() + assert result.dtype == expected_dtype + + +def test_series_to_numpy_datetime_with_tz_zero_copy() -> None: + values = [datetime(1970, 1, 1), datetime(2024, 2, 28)] + s = pl.Series(values).dt.convert_time_zone("Europe/Amsterdam") + result = s.to_numpy(use_pyarrow=False, allow_copy=False) + + assert_zero_copy(s, result) + assert result.tolist() == values + assert result.dtype == np.dtype("datetime64[us]") + + +def test_series_to_numpy_date() -> None: + values = [date(1970, 1, 1), date(2024, 2, 28)] + s = pl.Series(values) + + result = s.to_numpy(use_pyarrow=False) + + assert s.to_list() == result.tolist() + assert result.dtype == np.dtype("datetime64[D]") + assert_allow_copy_false_raises(s) + + +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (pl.Date, np.dtype("datetime64[D]")), + (pl.Duration("ms"), np.dtype("timedelta64[ms]")), + (pl.Duration("us"), np.dtype("timedelta64[us]")), + (pl.Duration("ns"), np.dtype("timedelta64[ns]")), + (pl.Datetime, np.dtype("datetime64[us]")), + (pl.Datetime("ms"), np.dtype("datetime64[ms]")), + (pl.Datetime("us"), np.dtype("datetime64[us]")), + (pl.Datetime("ns"), np.dtype("datetime64[ns]")), + ], +) +def test_series_to_numpy_temporal_with_nulls( + dtype: pl.PolarsDataType, expected_dtype: npt.DTypeLike +) -> None: + values = [0, 2_000, 1_000_000, None] + s = pl.Series(values, dtype=dtype, strict=False) + result = s.to_numpy(use_pyarrow=False) + + # NumPy tolist returns integers for ns precision + if getattr(s.dtype, "time_unit", None) == "ns": + assert result.tolist() == values + else: + assert result.tolist() == s.to_list() + assert result.dtype == expected_dtype + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_datetime_with_tz_with_nulls() -> None: + values = [datetime(1970, 1, 1), datetime(2024, 2, 28), None] + s = pl.Series(values).dt.convert_time_zone("Europe/Amsterdam") + result = s.to_numpy(use_pyarrow=False) + + assert result.tolist() == values + assert result.dtype == np.dtype("datetime64[us]") + assert_allow_copy_false_raises(s) + + +@pytest.mark.parametrize( + ("dtype", "values"), + [ + (pl.Time, [time(10, 30, 45), time(23, 59, 59)]), + (pl.Categorical, ["a", "b", "a"]), + (pl.Enum(["a", "b", "c"]), ["a", "b", "a"]), + (pl.String, ["a", "bc", "def"]), + (pl.Binary, [b"a", b"bc", b"def"]), + (pl.Decimal, [D("1.234"), D("2.345"), D("-3.456")]), + (pl.Object, [Path(), Path("abc")]), + # TODO: Implement for List types + # (pl.List, [[1], [2, 3]]), + # (pl.List, [["a"], ["b", "c"], []]), + ], +) +@pytest.mark.parametrize("with_nulls", [False, True]) +def test_to_numpy_object_dtypes( + dtype: pl.PolarsDataType, values: list[Any], with_nulls: bool +) -> None: + if with_nulls: + values.append(None) + + s = pl.Series(values, dtype=dtype) + result = s.to_numpy(use_pyarrow=False) + + assert result.tolist() == values + assert result.dtype == np.object_ + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_bool() -> None: + s = pl.Series([True, False]) + result = s.to_numpy(use_pyarrow=False) + + assert s.to_list() == result.tolist() + assert result.dtype == np.bool_ + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_bool_with_nulls() -> None: + s = pl.Series([True, False, None]) + result = s.to_numpy(use_pyarrow=False) + + assert s.to_list() == result.tolist() + assert result.dtype == np.object_ + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_array_of_int() -> None: + values = [[1, 2], [3, 4], [5, 6]] + s = pl.Series(values, dtype=pl.Array(pl.Int64, 2)) + result = s.to_numpy(use_pyarrow=False) + + expected = np.array(values) + assert_array_equal(result, expected) + assert result.dtype == np.int64 + + +def test_series_to_numpy_array_of_str() -> None: + values = [["1", "2", "3"], ["4", "5", "10000"]] + s = pl.Series(values, dtype=pl.Array(pl.String, 3)) + result = s.to_numpy(use_pyarrow=False) + assert result.tolist() == values + assert result.dtype == np.object_ + + +@pytest.mark.skip( + reason="Currently bugged, see: https://github.com/pola-rs/polars/issues/14268" +) +def test_series_to_numpy_array_with_nulls() -> None: + values = [[1, 2], [3, 4], None] + s = pl.Series(values, dtype=pl.Array(pl.Int64, 2)) + result = s.to_numpy(use_pyarrow=False) + + expected = np.array([[1.0, 2.0], [3.0, 4.0], [np.nan, np.nan]]) + assert_array_equal(result, expected) + assert result.dtype == np.float64 + assert_allow_copy_false_raises(s) + + +def test_to_numpy_null() -> None: + s = pl.Series([None, None], dtype=pl.Null) + result = s.to_numpy(use_pyarrow=False) + expected = np.array([np.nan, np.nan], dtype=np.float32) + assert_array_equal(result, expected) + assert result.dtype == np.float32 + assert_allow_copy_false_raises(s) + + +def test_to_numpy_empty() -> None: + s = pl.Series(dtype=pl.String) + result = s.to_numpy(use_pyarrow=False, allow_copy=False) + assert result.dtype == np.object_ + assert result.shape == (0,) + assert result.size == 0 + + +def test_to_numpy_chunked() -> None: + s1 = pl.Series([1, 2]) + s2 = pl.Series([3, 4]) + s = pl.concat([s1, s2], rechunk=False) + + result = s.to_numpy(use_pyarrow=False) + + assert result.tolist() == s.to_list() + assert result.dtype == np.int64 + assert_allow_copy_false_raises(s) + + +def test_zero_copy_only_deprecated() -> None: + values = [1, 2] + s = pl.Series([1, 2]) + with pytest.deprecated_call(): + result = s.to_numpy(zero_copy_only=True) + assert result.tolist() == values + + +def test_series_to_numpy_temporal() -> None: + s0 = pl.Series("date", [123543, 283478, 1243]).cast(pl.Date) + s1 = pl.Series( + "datetime", [datetime(2021, 1, 2, 3, 4, 5), datetime(2021, 2, 3, 4, 5, 6)] + ) + s2 = pl.datetime_range( + datetime(2021, 1, 1, 0), + datetime(2021, 1, 1, 1), + interval="1h", + time_unit="ms", + eager=True, + ) + assert str(s0.to_numpy()) == "['2308-04-02' '2746-02-20' '1973-05-28']" + assert ( + str(s1.to_numpy()[:2]) + == "['2021-01-02T03:04:05.000000' '2021-02-03T04:05:06.000000']" + ) + assert ( + str(s2.to_numpy()[:2]) + == "['2021-01-01T00:00:00.000' '2021-01-01T01:00:00.000']" + ) + s3 = pl.Series([timedelta(hours=1), timedelta(hours=-2)]) + out = np.array([3_600_000_000_000, -7_200_000_000_000], dtype="timedelta64[ns]") + assert (s3.to_numpy() == out).all() + + +@given( + s=series( + min_size=1, max_size=10, excluded_dtypes=[pl.Categorical, pl.List, pl.Struct] + ).filter( + lambda s: ( + getattr(s.dtype, "time_unit", None) != "ms" + and not (s.dtype == pl.String and s.str.contains("\x00").any()) + and not (s.dtype == pl.Binary and s.bin.contains(b"\x00").any()) + ) + ), +) +@settings(max_examples=250) +def test_series_to_numpy(s: pl.Series) -> None: + result = s.to_numpy(use_pyarrow=False) + + values = s.to_list() + dtype_map = { + pl.Datetime("ns"): "datetime64[ns]", + pl.Datetime("us"): "datetime64[us]", + pl.Duration("ns"): "timedelta64[ns]", + pl.Duration("us"): "timedelta64[us]", + } + np_dtype = dtype_map.get(s.dtype) # type: ignore[call-overload] + expected = np.array(values, dtype=np_dtype) + + assert_array_equal(result, expected) + + +@pytest.mark.parametrize("writable", [False, True]) +@pytest.mark.parametrize("pyarrow_available", [False, True]) +def test_to_numpy2( + writable: bool, pyarrow_available: bool, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(pl.series.series, "_PYARROW_AVAILABLE", pyarrow_available) + + np_array = pl.Series("a", [1, 2, 3], pl.UInt8).to_numpy(writable=writable) + + np.testing.assert_array_equal(np_array, np.array([1, 2, 3], dtype=np.uint8)) + # Test if numpy array is readonly or writable. + assert np_array.flags.writeable == writable + + if writable: + np_array[1] += 10 + np.testing.assert_array_equal(np_array, np.array([1, 12, 3], dtype=np.uint8)) + + np_array_with_missing_values = pl.Series("a", [None, 2, 3], pl.UInt8).to_numpy( + writable=writable + ) + + np.testing.assert_array_equal( + np_array_with_missing_values, + np.array( + [np.nan, 2.0, 3.0], + dtype=(np.float64 if pyarrow_available else np.float32), + ), + ) + + if writable: + # As Null values can't be encoded natively in a numpy array, + # this array will never be a view. + assert np_array_with_missing_values.flags.writeable == writable + + +def test_view() -> None: + s = pl.Series("a", [1.0, 2.5, 3.0]) + with pytest.deprecated_call(): + result = s.view() + assert isinstance(result, np.ndarray) + assert np.all(result == np.array([1.0, 2.5, 3.0])) + + +def test_view_nulls() -> None: + s = pl.Series("b", [1, 2, None]) + assert s.has_validity() + with pytest.deprecated_call(), pytest.raises(AssertionError): + s.view() + + +def test_view_nulls_sliced() -> None: + s = pl.Series("b", [1, 2, None]) + sliced = s[:2] + with pytest.deprecated_call(): + view = sliced.view() + assert np.all(view == np.array([1, 2])) + assert not sliced.has_validity() + + +def test_view_ub() -> None: + # this would be UB if the series was dropped and not passed to the view + s = pl.Series([3, 1, 5]) + with pytest.deprecated_call(): + result = s.sort().view() + assert np.sum(result) == 9 diff --git a/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py b/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py new file mode 100644 index 000000000000..8695d8d7e4b5 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import Any, cast + +import numpy as np + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_ufunc() -> None: + df = pl.DataFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)]) + out = df.select( + [ + np.power(pl.col("a"), 2).alias("power_uint8"), # type: ignore[call-overload] + np.power(pl.col("a"), 2.0).alias("power_float64"), # type: ignore[call-overload] + np.power(pl.col("a"), 2, dtype=np.uint16).alias("power_uint16"), # type: ignore[call-overload] + ] + ) + expected = pl.DataFrame( + [ + pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8), + pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16), + ] + ) + assert_frame_equal(out, expected) + assert out.dtypes == expected.dtypes + + +def test_ufunc_expr_not_first() -> None: + """Check numpy ufunc expressions also work if expression not the first argument.""" + df = pl.DataFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)]) + out = df.select( + [ + np.power(2.0, cast(Any, pl.col("a"))).alias("power"), + (2.0 / cast(Any, pl.col("a"))).alias("divide_scalar"), + (np.array([2, 2, 2]) / cast(Any, pl.col("a"))).alias("divide_array"), + ] + ) + expected = pl.DataFrame( + [ + pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64), + pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), + pl.Series("divide_array", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), + ] + ) + assert_frame_equal(out, expected) + + +def test_lazy_ufunc() -> None: + ldf = pl.LazyFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)]) + out = ldf.select( + [ + np.power(cast(Any, pl.col("a")), 2).alias("power_uint8"), + np.power(cast(Any, pl.col("a")), 2.0).alias("power_float64"), + np.power(cast(Any, pl.col("a")), 2, dtype=np.uint16).alias("power_uint16"), + ] + ) + expected = pl.DataFrame( + [ + pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8), + pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16), + ] + ) + assert_frame_equal(out.collect(), expected) + + +def test_lazy_ufunc_expr_not_first() -> None: + """Check numpy ufunc expressions also work if expression not the first argument.""" + ldf = pl.LazyFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)]) + out = ldf.select( + [ + np.power(2.0, cast(Any, pl.col("a"))).alias("power"), + (2.0 / cast(Any, pl.col("a"))).alias("divide_scalar"), + (np.array([2, 2, 2]) / cast(Any, pl.col("a"))).alias("divide_array"), + ] + ) + expected = pl.DataFrame( + [ + pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64), + pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), + pl.Series("divide_array", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), + ] + ) + assert_frame_equal(out.collect(), expected) + + +def test_ufunc_recognition() -> None: + df = pl.DataFrame({"a": [1, 1, 2, 2], "b": [1.1, 2.2, 3.3, 4.4]}) + assert_frame_equal(df.select(np.exp(pl.col("b"))), df.select(pl.col("b").exp())) + + +# https://github.com/pola-rs/polars/issues/6770 +def test_ufunc_multiple_expressions() -> None: + df = pl.DataFrame( + { + "v": [ + -4.293, + -2.4659, + -1.8378, + -0.2821, + -4.5649, + -3.8128, + -7.4274, + 3.3443, + 3.8604, + -4.2200, + ], + "u": [ + -11.2268, + 6.3478, + 7.1681, + 3.4986, + 2.7320, + -1.0695, + -10.1408, + 11.2327, + 6.6623, + -8.1412, + ], + } + ) + expected = np.arctan2(df.get_column("v"), df.get_column("u")) + result = df.select(np.arctan2(pl.col("v"), pl.col("u")))[:, 0] # type: ignore[call-overload] + assert_series_equal(expected, result) # type: ignore[arg-type] + + +def test_grouped_ufunc() -> None: + df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [0.1, 0.1, -0.1, -0.1]}) + df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1)) diff --git a/py-polars/tests/unit/interop/numpy/test_ufunc_series.py b/py-polars/tests/unit/interop/numpy/test_ufunc_series.py new file mode 100644 index 000000000000..917b54c9eba2 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_ufunc_series.py @@ -0,0 +1,121 @@ +from typing import cast + +import numpy as np +from numpy.testing import assert_array_equal + +import polars as pl +from polars.testing import assert_series_equal + + +def test_ufunc() -> None: + # test if output dtype is calculated correctly. + s_float32 = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=pl.Float32) + assert_series_equal( + cast(pl.Series, np.multiply(s_float32, 4)), + pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=pl.Float32), + ) + + s_float64 = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=pl.Float64) + assert_series_equal( + cast(pl.Series, np.multiply(s_float64, 4)), + pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=pl.Float64), + ) + + s_uint8 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8) + assert_series_equal( + cast(pl.Series, np.power(s_uint8, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt8), + ) + assert_series_equal( + cast(pl.Series, np.power(s_uint8, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + assert_series_equal( + cast(pl.Series, np.power(s_uint8, 2, dtype=np.uint16)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt16), + ) + + s_int8 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int8) + assert_series_equal( + cast(pl.Series, np.power(s_int8, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.Int8), + ) + assert_series_equal( + cast(pl.Series, np.power(s_int8, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + assert_series_equal( + cast(pl.Series, np.power(s_int8, 2, dtype=np.int16)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.Int16), + ) + + s_uint32 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt32) + assert_series_equal( + cast(pl.Series, np.power(s_uint32, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt32), + ) + assert_series_equal( + cast(pl.Series, np.power(s_uint32, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + + s_int32 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int32) + assert_series_equal( + cast(pl.Series, np.power(s_int32, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.Int32), + ) + assert_series_equal( + cast(pl.Series, np.power(s_int32, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + + s_uint64 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt64) + assert_series_equal( + cast(pl.Series, np.power(s_uint64, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt64), + ) + assert_series_equal( + cast(pl.Series, np.power(s_uint64, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + + s_int64 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int64) + assert_series_equal( + cast(pl.Series, np.power(s_int64, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.Int64), + ) + assert_series_equal( + cast(pl.Series, np.power(s_int64, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + + # test if null bitmask is preserved + a1 = pl.Series("a", [1.0, None, 3.0]) + b1 = cast(pl.Series, np.exp(a1)) + assert b1.null_count() == 1 + + # test if it works with chunked series. + a2 = pl.Series("a", [1.0, None, 3.0]) + b2 = pl.Series("b", [4.0, 5.0, None]) + a2.append(b2) + assert a2.n_chunks() == 2 + c2 = np.multiply(a2, 3) + assert_series_equal( + cast(pl.Series, c2), + pl.Series("a", [3.0, None, 9.0, 12.0, 15.0, None]), + ) + + # Test if nulls propagate through ufuncs + a3 = pl.Series("a", [None, None, 3, 3]) + b3 = pl.Series("b", [None, 3, None, 3]) + assert_series_equal( + cast(pl.Series, np.maximum(a3, b3)), pl.Series("a", [None, None, None, 3]) + ) + + +def test_numpy_string_array() -> None: + s_str = pl.Series("a", ["aa", "bb", "cc", "dd"], dtype=pl.String) + assert_array_equal( + np.char.capitalize(s_str), + np.array(["Aa", "Bb", "Cc", "Dd"], dtype=" Any: - return request.param - - -def test_df_from_numpy(numpy_interop_test_data: Any) -> None: - name, values, pl_dtype, np_dtype = numpy_interop_test_data - df = pl.DataFrame({name: np.array(values, dtype=np_dtype)}) - assert [pl_dtype] == df.dtypes - - -def test_asarray(numpy_interop_test_data: Any) -> None: - name, values, pl_dtype, np_dtype = numpy_interop_test_data - pl_series_to_numpy_array = np.asarray(pl.Series(name, values, pl_dtype)) - numpy_array = np.asarray(values, dtype=np_dtype) - assert_array_equal(pl_series_to_numpy_array, numpy_array) - - -@pytest.mark.parametrize("use_pyarrow", [True, False]) -def test_to_numpy(numpy_interop_test_data: Any, use_pyarrow: bool) -> None: - name, values, pl_dtype, np_dtype = numpy_interop_test_data - pl_series_to_numpy_array = pl.Series(name, values, pl_dtype).to_numpy( - use_pyarrow=use_pyarrow - ) - numpy_array = np.asarray(values, dtype=np_dtype) - assert_array_equal(pl_series_to_numpy_array, numpy_array) - - -@pytest.mark.parametrize("use_pyarrow", [True, False]) -@pytest.mark.parametrize("has_null", [True, False]) -@pytest.mark.parametrize("dtype", [pl.Time, pl.Boolean, pl.String]) -def test_to_numpy_no_zero_copy( - use_pyarrow: bool, has_null: bool, dtype: pl.PolarsDataType -) -> None: - data: list[Any] = ["a", None] if dtype == pl.String else [0, None] - series = pl.Series(data if has_null else data[:1], dtype=dtype) - with pytest.raises(ValueError): - series.to_numpy(zero_copy_only=True, use_pyarrow=use_pyarrow) - - -def test_to_numpy_empty_no_pyarrow() -> None: - series = pl.Series([], dtype=pl.Null) - result = series.to_numpy() - assert result.dtype == pl.Float32 - assert result.shape == (0,) - assert result.size == 0 +if TYPE_CHECKING: + from polars.type_aliases import PolarsDataType def test_from_pandas() -> None: @@ -182,6 +111,44 @@ def test_from_pandas_datetime() -> None: assert s[-1] == datetime(2021, 6, 24, 9, 0) +@pytest.mark.parametrize( + ("index_class", "index_data", "index_params", "expected_data", "expected_dtype"), + [ + (pd.Index, [100, 200, 300], {}, None, pl.Int64), + (pd.Index, [1, 2, 3], {"dtype": "uint32"}, None, pl.UInt32), + (pd.RangeIndex, 5, {}, [0, 1, 2, 3, 4], pl.Int64), + (pd.CategoricalIndex, ["N", "E", "S", "W"], {}, None, pl.Categorical), + ( + pd.DatetimeIndex, + [datetime(1960, 12, 31), datetime(2077, 10, 20)], + {"dtype": "datetime64[ms]"}, + None, + pl.Datetime("ms"), + ), + ( + pd.TimedeltaIndex, + ["24 hours", "2 days 8 hours", "3 days 42 seconds"], + {}, + [timedelta(1), timedelta(days=2, hours=8), timedelta(days=3, seconds=42)], + pl.Duration("ns"), + ), + ], +) +def test_from_pandas_index( + index_class: Any, + index_data: Any, + index_params: dict[str, Any], + expected_data: list[Any] | None, + expected_dtype: PolarsDataType, +) -> None: + if expected_data is None: + expected_data = index_data + + s = pl.from_pandas(index_class(index_data, **index_params)) + assert s.to_list() == expected_data + assert s.dtype == expected_dtype + + def test_from_pandas_include_indexes() -> None: data = { "dtm": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)], @@ -408,83 +375,6 @@ def test_from_records() -> None: assert df.rows() == [(1, 4), (2, 5), (3, 6)] -def test_from_numpy() -> None: - data = np.array([[1, 2, 3], [4, 5, 6]]) - df = pl.from_numpy( - data, - schema=["a", "b"], - orient="col", - schema_overrides={"a": pl.UInt32, "b": pl.UInt32}, - ) - assert df.shape == (3, 2) - assert df.rows() == [(1, 4), (2, 5), (3, 6)] - assert df.schema == {"a": pl.UInt32, "b": pl.UInt32} - data2 = np.array(["foo", "bar"], dtype=object) - df2 = pl.from_numpy(data2) - assert df2.shape == (2, 1) - assert df2.rows() == [("foo",), ("bar",)] - assert df2.schema == {"column_0": pl.String} - with pytest.raises( - ValueError, - match="cannot create DataFrame from array with more than two dimensions", - ): - _ = pl.from_numpy(np.array([[[1]]])) - with pytest.raises( - ValueError, match="cannot create DataFrame from zero-dimensional array" - ): - _ = pl.from_numpy(np.array(1)) - - -def test_from_numpy_structured() -> None: - test_data = [ - ("Google Pixel 7", 521.90, True), - ("Apple iPhone 14 Pro", 999.00, True), - ("Samsung Galaxy S23 Ultra", 1199.99, False), - ("OnePlus 11", 699.00, True), - ] - # create a numpy structured array... - arr_structured = np.array( - test_data, - dtype=np.dtype( - [ - ("product", "U32"), - ("price_usd", "float64"), - ("in_stock", "bool"), - ] - ), - ) - # ...and also establish as a record array view - arr_records = arr_structured.view(np.recarray) - - # confirm that we can cleanly initialise a DataFrame from both, - # respecting the native dtypes and any schema overrides, etc. - for arr in (arr_structured, arr_records): - df = pl.DataFrame(data=arr).sort(by="price_usd", descending=True) - - assert df.schema == { - "product": pl.String, - "price_usd": pl.Float64, - "in_stock": pl.Boolean, - } - assert df.rows() == sorted(test_data, key=lambda row: -row[1]) - - for df in ( - pl.DataFrame( - data=arr, schema=["phone", ("price_usd", pl.Float32), "available"] - ), - pl.DataFrame( - data=arr, - schema=["phone", "price_usd", "available"], - schema_overrides={"price_usd": pl.Float32}, - ), - ): - assert df.schema == { - "phone": pl.String, - "price_usd": pl.Float32, - "available": pl.Boolean, - } - - def test_from_arrow() -> None: data = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}) df = pl.from_arrow(data) @@ -567,13 +457,6 @@ def test_no_rechunk() -> None: assert pl.from_arrow(table["x"], rechunk=False).n_chunks() == 2 -def test_numpy_to_lit() -> None: - out = pl.select(pl.lit(np.array([1, 2, 3]))).to_series().to_list() - assert out == [1, 2, 3] - out = pl.select(pl.lit(np.float32(0))).to_series().to_list() - assert out == [0.0] - - def test_from_empty_pandas() -> None: pandas_df = pd.DataFrame( { @@ -663,12 +546,6 @@ def test_from_pyarrow_chunked_array() -> None: assert series.to_list() == [1, 2] -def test_numpy_preserve_uint64_4112() -> None: - df = pl.DataFrame({"a": [1, 2, 3]}).with_columns(pl.col("a").hash()) - assert df.to_numpy().dtype == np.dtype("uint64") - assert df.to_numpy(structured=True).dtype == np.dtype([("a", "uint64")]) - - def test_arrow_list_null_5697() -> None: # Create a pyarrow table with a list[null] column. pa_table = pa.table([[[None]]], names=["mycol"]) @@ -711,29 +588,6 @@ def test_from_pyarrow_map() -> None: } -def test_to_numpy_datelike() -> None: - s = pl.Series( - "dt", - [ - datetime(2022, 7, 5, 10, 30, 45, 123456), - None, - datetime(2023, 2, 5, 15, 22, 30, 987654), - ], - ) - assert str(s.to_numpy()) == str( - np.array( - ["2022-07-05T10:30:45.123456", "NaT", "2023-02-05T15:22:30.987654"], - dtype="datetime64[us]", - ) - ) - assert str(s.drop_nulls().to_numpy()) == str( - np.array( - ["2022-07-05T10:30:45.123456", "2023-02-05T15:22:30.987654"], - dtype="datetime64[us]", - ) - ) - - def test_from_fixed_size_binary_list() -> None: val = [[b"63A0B1C66575DD5708E1EB2B"]] arrow_array = pa.array(val, type=pa.list_(pa.binary(24))) diff --git a/py-polars/tests/unit/interop/test_numpy.py b/py-polars/tests/unit/interop/test_numpy.py deleted file mode 100644 index 9de3c616f308..000000000000 --- a/py-polars/tests/unit/interop/test_numpy.py +++ /dev/null @@ -1,51 +0,0 @@ -import numpy as np -import pytest - -import polars as pl - - -def test_view() -> None: - s = pl.Series("a", [1.0, 2.5, 3.0]) - result = s._view() - assert isinstance(result, np.ndarray) - assert np.all(result == np.array([1.0, 2.5, 3.0])) - - -def test_view_nulls() -> None: - s = pl.Series("b", [1, 2, None]) - assert s.has_validity() - with pytest.raises(AssertionError): - s._view() - - -def test_view_nulls_sliced() -> None: - s = pl.Series("b", [1, 2, None]) - sliced = s[:2] - assert np.all(sliced._view() == np.array([1, 2])) - assert not sliced.has_validity() - - -def test_view_ub() -> None: - # this would be UB if the series was dropped and not passed to the view - s = pl.Series([3, 1, 5]) - result = s.sort()._view() - assert np.sum(result) == 9 - - -def test_view_deprecated() -> None: - s = pl.Series("a", [1.0, 2.5, 3.0]) - with pytest.deprecated_call(): - result = s.view() - assert isinstance(result, np.ndarray) - assert np.all(result == np.array([1.0, 2.5, 3.0])) - - -def test_numpy_disambiguation() -> None: - a = np.array([1, 2]) - df = pl.DataFrame({"a": a}) - result = df.with_columns(b=a).to_dict(as_series=False) # type: ignore[arg-type] - expected = { - "a": [1, 2], - "b": [1, 2], - } - assert result == expected diff --git a/py-polars/tests/unit/interop/test_to_pandas.py b/py-polars/tests/unit/interop/test_to_pandas.py index 2d15c16dcd7f..061affd14954 100644 --- a/py-polars/tests/unit/interop/test_to_pandas.py +++ b/py-polars/tests/unit/interop/test_to_pandas.py @@ -179,3 +179,18 @@ def test_object_to_pandas_series(use_pyarrow_extension_array: bool) -> None: ), pd.Series(values, dtype=object, name="a"), ) + + +@pytest.mark.parametrize("polars_dtype", [pl.Categorical, pl.Enum(["a", "b"])]) +def test_series_to_pandas_categorical(polars_dtype: pl.PolarsDataType) -> None: + s = pl.Series("x", ["a", "b", "a"], dtype=polars_dtype) + result = s.to_pandas() + expected = pd.Series(["a", "b", "a"], name="x", dtype="category") + pd.testing.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("polars_dtype", [pl.Categorical, pl.Enum(["a", "b"])]) +def test_series_to_pandas_categorical_pyarrow(polars_dtype: pl.PolarsDataType) -> None: + s = pl.Series("x", ["a", "b", "a"], dtype=polars_dtype) + result = s.to_pandas(use_pyarrow_extension_array=True) + assert s.to_list() == result.to_list() diff --git a/py-polars/tests/unit/io/cloud/test_aws.py b/py-polars/tests/unit/io/cloud/test_aws.py index 03fb72b7eff4..652be4257658 100644 --- a/py-polars/tests/unit/io/cloud/test_aws.py +++ b/py-polars/tests/unit/io/cloud/test_aws.py @@ -1,5 +1,6 @@ from __future__ import annotations +import multiprocessing from typing import TYPE_CHECKING, Any, Callable, Iterator import boto3 @@ -7,6 +8,7 @@ from moto.server import ThreadedMotoServer import polars as pl +from polars.testing import assert_frame_equal if TYPE_CHECKING: from pathlib import Path @@ -33,12 +35,14 @@ def s3_base(monkeypatch_module: Any) -> Iterator[str]: host = "127.0.0.1" port = 5000 moto_server = ThreadedMotoServer(host, port) - - moto_server.start() + # Start in a separate process to avoid deadlocks + mp = multiprocessing.get_context("spawn") + p = mp.Process(target=moto_server._server_entry, daemon=True) + p.start() print("server up") yield f"http://{host}:{port}" print("moto done") - moto_server.stop() + p.kill() @pytest.fixture() @@ -47,7 +51,7 @@ def s3(s3_base: str, io_files_path: Path) -> str: client = boto3.client("s3", region_name=region, endpoint_url=s3_base) client.create_bucket(Bucket="bucket") - files = ["foods1.csv", "foods1.ipc", "foods1.parquet"] + files = ["foods1.csv", "foods1.ipc", "foods1.parquet", "foods2.parquet"] for file in files: client.upload_file(io_files_path / file, Bucket="bucket", Key=file) return s3_base @@ -55,10 +59,7 @@ def s3(s3_base: str, io_files_path: Path) -> str: @pytest.mark.parametrize( ("function", "extension"), - [ - (pl.read_csv, "csv"), - (pl.read_ipc, "ipc"), - ], + [(pl.read_csv, "csv"), (pl.read_ipc, "ipc")], ) def test_read_s3(s3: str, function: Callable[..., Any], extension: str) -> None: df = function( @@ -71,9 +72,7 @@ def test_read_s3(s3: str, function: Callable[..., Any], extension: str) -> None: @pytest.mark.parametrize( ("function", "extension"), - [ - (pl.scan_ipc, "ipc"), - ], + [(pl.scan_ipc, "ipc"), (pl.scan_parquet, "parquet")], ) def test_scan_s3(s3: str, function: Callable[..., Any], extension: str) -> None: df = function( @@ -82,3 +81,13 @@ def test_scan_s3(s3: str, function: Callable[..., Any], extension: str) -> None: ) assert df.columns == ["category", "calories", "fats_g", "sugars_g"] assert df.collect().shape == (27, 4) + + +def test_lazy_count_s3(s3: str) -> None: + lf = pl.scan_parquet( + "s3://bucket/foods*.parquet", storage_options={"endpoint_url": s3} + ).select(pl.len()) + + assert "FAST COUNT(*)" in lf.explain() + expected = pl.DataFrame({"len": [54]}, schema={"len": pl.UInt32}) + assert_frame_equal(lf.collect(), expected) diff --git a/py-polars/tests/unit/io/files/example.xls b/py-polars/tests/unit/io/files/example.xls new file mode 100644 index 000000000000..94182083ac23 Binary files /dev/null and b/py-polars/tests/unit/io/files/example.xls differ diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index c6c74f6adbdd..70bce7dfe921 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -243,6 +243,49 @@ def test_csv_missing_utf8_is_empty_string() -> None: ] +def test_csv_int_types() -> None: + f = io.StringIO( + "u8,i8,u16,i16,u32,i32,u64,i64\n" + "0,0,0,0,0,0,0,0\n" + "0,-128,0,-32768,0,-2147483648,0,-9223372036854775808\n" + "255,127,65535,32767,4294967295,2147483647,18446744073709551615,9223372036854775807\n" + "01,01,01,01,01,01,01,01\n" + "01,-01,01,-01,01,-01,01,-01\n" + ) + df = pl.read_csv( + f, + schema={ + "u8": pl.UInt8, + "i8": pl.Int8, + "u16": pl.UInt16, + "i16": pl.Int16, + "u32": pl.UInt32, + "i32": pl.Int32, + "u64": pl.UInt64, + "i64": pl.Int64, + }, + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "u8": pl.Series([0, 0, 255, 1, 1], dtype=pl.UInt8), + "i8": pl.Series([0, -128, 127, 1, -1], dtype=pl.Int8), + "u16": pl.Series([0, 0, 65535, 1, 1], dtype=pl.UInt16), + "i16": pl.Series([0, -32768, 32767, 1, -1], dtype=pl.Int16), + "u32": pl.Series([0, 0, 4294967295, 1, 1], dtype=pl.UInt32), + "i32": pl.Series([0, -2147483648, 2147483647, 1, -1], dtype=pl.Int32), + "u64": pl.Series([0, 0, 18446744073709551615, 1, 1], dtype=pl.UInt64), + "i64": pl.Series( + [0, -9223372036854775808, 9223372036854775807, 1, -1], + dtype=pl.Int64, + ), + } + ), + ) + + def test_csv_float_parsing() -> None: lines_with_floats = [ "123.86,+123.86,-123.86\n", @@ -715,7 +758,7 @@ def test_csv_date_handling() -> None: 1742-03-21 1743-06-16 1730-07-22 - "" + 1739-03-16 """ ) @@ -738,6 +781,67 @@ def test_csv_date_handling() -> None: assert_frame_equal(out, expected) +def test_csv_no_date_dtype_because_string() -> None: + csv = textwrap.dedent( + """\ + date + 2024-01-01 + 2024-01-02 + hello + """ + ) + out = pl.read_csv(csv.encode(), try_parse_dates=True) + assert out.dtypes == [pl.String] + + +def test_csv_infer_date_dtype() -> None: + csv = textwrap.dedent( + """\ + date + 2024-01-01 + "2024-01-02" + + 2024-01-04 + """ + ) + out = pl.read_csv(csv.encode(), try_parse_dates=True) + expected = pl.DataFrame( + { + "date": [ + date(2024, 1, 1), + date(2024, 1, 2), + None, + date(2024, 1, 4), + ] + } + ) + assert_frame_equal(out, expected) + + +def test_csv_date_dtype_ignore_errors() -> None: + csv = textwrap.dedent( + """\ + date + hello + 2024-01-02 + world + !! + """ + ) + out = pl.read_csv(csv.encode(), ignore_errors=True, dtypes={"date": pl.Date}) + expected = pl.DataFrame( + { + "date": [ + None, + date(2024, 1, 2), + None, + None, + ] + } + ) + assert_frame_equal(out, expected) + + def test_csv_globbing(io_files_path: Path) -> None: path = io_files_path / "foods*.csv" df = pl.read_csv(path) @@ -869,6 +973,28 @@ def test_quoting_round_trip() -> None: assert_frame_equal(read_df, df) +def test_csv_field_schema_inference_with_whitespace() -> None: + csv = """\ +bool,bool-,-bool,float,float-,-float,int,int-,-int +true,true , true,1.2,1.2 , 1.2,1,1 , 1 +""" + df = pl.read_csv(io.StringIO(csv), has_header=True) + expected = pl.DataFrame( + { + "bool": [True], + "bool-": ["true "], + "-bool": [" true"], + "float": [1.2], + "float-": ["1.2 "], + "-float": [" 1.2"], + "int": [1], + "int-": ["1 "], + "-int": [" 1"], + } + ) + assert_frame_equal(df, expected) + + def test_fallback_chrono_parser() -> None: data = textwrap.dedent( """\ @@ -1215,7 +1341,8 @@ def test_float_precision(dtype: pl.Float32 | pl.Float64) -> None: def test_skip_rows_different_field_len() -> None: csv = io.StringIO( textwrap.dedent( - """a,b + """\ + a,b 1,A 2, 3,B @@ -1349,6 +1476,11 @@ def test_batched_csv_reader_no_batches(foods_file_path: Path) -> None: assert batches is None +def test_read_csv_batched_invalid_source() -> None: + with pytest.raises(TypeError): + pl.read_csv_batched(source=5) # type: ignore[arg-type] + + def test_csv_single_categorical_null() -> None: f = io.BytesIO() pl.DataFrame( @@ -1514,6 +1646,24 @@ def test_read_csv_n_rows_outside_heuristic() -> None: assert pl.read_csv(f, n_rows=2048, has_header=False).shape == (2048, 4) +def test_read_csv_comments_on_top_with_schema_11667() -> None: + csv = """ +# This is a comment +A,B +1,Hello +2,World +""".strip() + + schema = { + "A": pl.Int32(), + "B": pl.Utf8(), + } + + df = pl.read_csv(io.StringIO(csv), comment_prefix="#", schema=schema) + assert len(df) == 2 + assert df.schema == schema + + def test_write_csv_stdout_stderr(capsys: pytest.CaptureFixture[str]) -> None: # The capsys fixture allows pytest to access stdout/stderr. See # https://docs.pytest.org/en/7.1.x/how-to/capture-stdout-stderr.html @@ -1526,7 +1676,7 @@ def test_write_csv_stdout_stderr(capsys: pytest.CaptureFixture[str]) -> None: ) # pytest hijacks sys.stdout and changes its type, which causes mypy failure - df.write_csv(sys.stdout) # type: ignore[call-overload] + df.write_csv(sys.stdout) captured = capsys.readouterr() assert captured.out == ( "numbers,strings,dates\n" @@ -1535,7 +1685,7 @@ def test_write_csv_stdout_stderr(capsys: pytest.CaptureFixture[str]) -> None: "3,stdout,2023-01-03\n" ) - df.write_csv(sys.stderr) # type: ignore[call-overload] + df.write_csv(sys.stderr) captured = capsys.readouterr() assert captured.err == ( "numbers,strings,dates\n" @@ -1675,7 +1825,7 @@ def test_provide_schema() -> None: } -def test_custom_writeable_object() -> None: +def test_custom_writable_object() -> None: df = pl.DataFrame({"a": [10, 20, 30], "b": ["x", "y", "z"]}) class CustomBuffer: @@ -1781,3 +1931,32 @@ def test_partial_read_compressed_file(tmp_path: Path) -> None: file_path, skip_rows=40, has_header=False, skip_rows_after_header=20, n_rows=30 ) assert df.shape == (30, 3) + + +def test_read_csv_invalid_dtypes() -> None: + csv = textwrap.dedent( + """\ + a,b + 1,foo + 2,bar + 3,baz + """ + ) + f = io.StringIO(csv) + with pytest.raises(TypeError, match="`dtypes` should be of type list or dict"): + pl.read_csv(f, dtypes={pl.Int64, pl.String}) # type: ignore[arg-type] + + +@pytest.mark.parametrize("columns", [["b"], "b"]) +def test_read_csv_single_column(columns: list[str] | str) -> None: + csv = textwrap.dedent( + """\ + a,b,c + 1,2,3 + 4,5,6 + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, columns=columns) + expected = pl.DataFrame({"b": [2, 5]}) + assert_frame_equal(df, expected) diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 7b4ad1d8bc32..9fc68921905d 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -11,6 +11,7 @@ import pytest from sqlalchemy import Integer, MetaData, Table, create_engine, func, select +from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.expression import cast as alchemy_cast import polars as pl @@ -353,8 +354,13 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None: # setup underlying test data tmp_path.mkdir(exist_ok=True) create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) - conn = create_engine(f"sqlite:///{test_db}") - t = Table("test_data", MetaData(), autoload_with=conn) + + # various flavours of alchemy connection + alchemy_engine = create_engine(f"sqlite:///{test_db}") + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + + t = Table("test_data", MetaData(), autoload_with=alchemy_engine) # establish sqlalchemy "selectable" and validate usage selectable_query = select( @@ -363,21 +369,23 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None: t.c.value, ).where(t.c.value < 0) - assert_frame_equal( - pl.read_database(selectable_query, connection=conn.connect()), - pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}), - ) + for conn in (alchemy_session, alchemy_engine, alchemy_conn): + assert_frame_equal( + pl.read_database(selectable_query, connection=conn), + pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}), + ) def test_read_database_parameterised(tmp_path: Path) -> None: # setup underlying test data tmp_path.mkdir(exist_ok=True) create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) + alchemy_engine = create_engine(f"sqlite:///{test_db}") # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs raw_conn: ConnectionOrCursor = sqlite3.connect(test_db) - alchemy_conn: ConnectionOrCursor = create_engine(f"sqlite:///{test_db}").connect() - test_conns = (alchemy_conn, raw_conn) + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() # establish parameterised queries and validate usage query = """ @@ -390,7 +398,10 @@ def test_read_database_parameterised(tmp_path: Path) -> None: ("?", (0,)), ("?", [0]), ): - for conn in test_conns: + for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn): + if alchemy_session is conn and param == "?": + continue # alchemy session.execute() doesn't support positional params + assert_frame_equal( pl.read_database( query.format(n=param), diff --git a/py-polars/tests/unit/io/test_delta.py b/py-polars/tests/unit/io/test_delta.py index 56bf30fd169b..46f097863a98 100644 --- a/py-polars/tests/unit/io/test_delta.py +++ b/py-polars/tests/unit/io/test_delta.py @@ -388,7 +388,7 @@ def test_write_delta_with_merge_and_no_table(tmp_path: Path) -> None: def test_write_delta_with_merge(tmp_path: Path) -> None: df = pl.DataFrame({"a": [1, 2, 3]}) - df.write_delta(tmp_path, mode="append") + df.write_delta(tmp_path) merger = df.write_delta( tmp_path, @@ -407,6 +407,7 @@ def test_write_delta_with_merge(tmp_path: Path) -> None: merger.when_matched_delete(predicate="t.a > 2").execute() - table = pl.read_delta(str(tmp_path)) + result = pl.read_delta(str(tmp_path)) - assert_frame_equal(df.filter(pl.col("a") <= 2), table) + expected = df.filter(pl.col("a") <= 2) + assert_frame_equal(result, expected, check_row_order=False) diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py index ad4145abf8df..67ddef655366 100644 --- a/py-polars/tests/unit/io/test_hive.py +++ b/py-polars/tests/unit/io/test_hive.py @@ -176,3 +176,28 @@ def test_hive_partitioned_err(io_files_path: Path, tmp_path: Path) -> None: with pytest.raises(pl.ComputeError, match="invalid hive partitions"): pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True) + + +@pytest.mark.write_disk() +def test_hive_partitioned_projection_skip_files( + io_files_path: Path, tmp_path: Path +) -> None: + # ensure that it makes hive columns even when . in dir value + # and that it doesn't make hive columns from filename with = + df = pl.DataFrame( + {"sqlver": [10012.0, 10013.0], "namespace": ["eos", "fda"], "a": [1, 2]} + ) + root = tmp_path / "partitioned_data" + for dir_tuple, sub_df in df.partition_by( + ["sqlver", "namespace"], include_key=False, as_dict=True + ).items(): + new_path = root / f"sqlver={dir_tuple[0]}" / f"namespace={dir_tuple[1]}" + new_path.mkdir(parents=True, exist_ok=True) + sub_df.write_parquet(new_path / "file=8484.parquet") + test_df = ( + pl.scan_parquet(str(root) + "/**/**/*.parquet") + # don't care about column order + .select("sqlver", "namespace", "a", pl.exclude("sqlver", "namespace", "a")) + .collect() + ) + assert_frame_equal(df, test_df) diff --git a/py-polars/tests/unit/io/test_ipc.py b/py-polars/tests/unit/io/test_ipc.py index 679ec8842a8a..41b47c135ec8 100644 --- a/py-polars/tests/unit/io/test_ipc.py +++ b/py-polars/tests/unit/io/test_ipc.py @@ -241,3 +241,15 @@ def test_struct_nested_enum() -> None: df.write_ipc(buffer) df = pl.read_ipc(buffer) assert df.get_column("struct_cat").dtype == dtype + + +@pytest.mark.slow() +def test_ipc_view_gc_14448() -> None: + f = io.BytesIO() + # This size was required to trigger the bug + df = pl.DataFrame( + pl.Series(["small"] * 10 + ["looooooong string......."] * 750).slice(20, 20) + ) + df.write_ipc(f, future=True) + f.seek(0) + assert_frame_equal(pl.read_ipc(f), df) diff --git a/py-polars/tests/unit/io/test_json.py b/py-polars/tests/unit/io/test_json.py index c937d0fe140e..74e38c3c186e 100644 --- a/py-polars/tests/unit/io/test_json.py +++ b/py-polars/tests/unit/io/test_json.py @@ -284,10 +284,10 @@ def test_write_json_duration() -> None: ) } ) - assert ( - df.write_json(row_oriented=True) - == '[{"a":"P1DT5362.939S"},{"a":"P1DT5362.890S"},{"a":"PT6020.836S"}]' - ) + + # we don't guarantee a format, just round-circling + value = str(df.write_json(row_oriented=True)) + assert value == """[{"a":"PT91762.939S"},{"a":"PT91762.89S"},{"a":"PT6020.836S"}]""" @pytest.mark.parametrize( diff --git a/py-polars/tests/unit/io/test_lazy_count_star.py b/py-polars/tests/unit/io/test_lazy_count_star.py new file mode 100644 index 000000000000..7ab69ad73aab --- /dev/null +++ b/py-polars/tests/unit/io/test_lazy_count_star.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +from tempfile import NamedTemporaryFile + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +@pytest.mark.parametrize( + ("path", "n_rows"), [("foods1.csv", 27), ("foods*.csv", 27 * 5)] +) +def test_count_csv(io_files_path: Path, path: str, n_rows: int) -> None: + lf = pl.scan_csv(io_files_path / path).select(pl.len()) + + expected = pl.DataFrame(pl.Series("len", [n_rows], dtype=pl.UInt32)) + + # Check if we are using our fast count star + assert "FAST COUNT(*)" in lf.explain() + assert_frame_equal(lf.collect(), expected) + + +@pytest.mark.write_disk() +def test_commented_csv() -> None: + csv_a = NamedTemporaryFile() + csv_a.write( + b""" +A,B +Gr1,A +Gr1,B +# comment line + """.strip() + ) + csv_a.seek(0) + + expected = pl.DataFrame(pl.Series("len", [2], dtype=pl.UInt32)) + lf = pl.scan_csv(csv_a.name, comment_prefix="#").select(pl.len()) + assert "FAST COUNT(*)" in lf.explain() + assert_frame_equal(lf.collect(), expected) + + +@pytest.mark.parametrize( + ("pattern", "n_rows"), [("small.parquet", 4), ("foods*.parquet", 54)] +) +def test_count_parquet(io_files_path: Path, pattern: str, n_rows: int) -> None: + lf = pl.scan_parquet(io_files_path / pattern).select(pl.len()) + + expected = pl.DataFrame(pl.Series("len", [n_rows], dtype=pl.UInt32)) + + # Check if we are using our fast count star + assert "FAST COUNT(*)" in lf.explain() + assert_frame_equal(lf.collect(), expected) + + +@pytest.mark.parametrize( + ("path", "n_rows"), [("foods1.ipc", 27), ("foods*.ipc", 27 * 2)] +) +def test_count_ipc(io_files_path: Path, path: str, n_rows: int) -> None: + lf = pl.scan_ipc(io_files_path / path).select(pl.len()) + + expected = pl.DataFrame(pl.Series("len", [n_rows], dtype=pl.UInt32)) + + # Check if we are using our fast count star + assert "FAST COUNT(*)" in lf.explain() + assert_frame_equal(lf.collect(), expected) diff --git a/py-polars/tests/unit/io/test_lazy_csv.py b/py-polars/tests/unit/io/test_lazy_csv.py index 22e57462ae49..59bb84d72658 100644 --- a/py-polars/tests/unit/io/test_lazy_csv.py +++ b/py-polars/tests/unit/io/test_lazy_csv.py @@ -24,7 +24,7 @@ def test_scan_csv(io_files_path: Path) -> None: def test_scan_csv_no_cse_deadlock(io_files_path: Path) -> None: - dfs = [pl.scan_csv(io_files_path / "small.csv")] * (pl.threadpool_size() + 1) + dfs = [pl.scan_csv(io_files_path / "small.csv")] * (pl.thread_pool_size() + 1) pl.concat(dfs, parallel=True).collect(comm_subplan_elim=False) diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 36d13e12f0e2..f30045996793 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -709,3 +709,24 @@ def test_utc_timezone_normalization_13670(tmp_path: Path) -> None: assert cast(pl.Datetime, df.schema["c1"]).time_zone == "UTC" df = pl.scan_parquet([zero_path, utc_path]).head(5).collect() assert cast(pl.Datetime, df.schema["c1"]).time_zone == "UTC" + + +def test_parquet_rle_14333() -> None: + vals = [True, False, True, False, True, False, True, False, True, False] + table = pa.table({"a": vals}) + + f = io.BytesIO() + pq.write_table(table, f, data_page_version="2.0") + f.seek(0) + assert pl.read_parquet(f)["a"].to_list() == vals + + +def test_parquet_rle_null_binary_read_14638() -> None: + df = pl.DataFrame({"x": [None]}, schema={"x": pl.String}) + + f = io.BytesIO() + df.write_parquet(f, use_pyarrow=True) + f.seek(0) + assert "RLE_DICTIONARY" in pq.read_metadata(f).row_group(0).column(0).encodings + f.seek(0) + assert_frame_equal(df, pl.read_parquet(f)) diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index 91ee0d634145..085b50cb6a01 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -1,10 +1,10 @@ from __future__ import annotations -import sys import warnings from collections import OrderedDict from datetime import date, datetime from io import BytesIO +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable import pytest @@ -12,21 +12,39 @@ import polars as pl import polars.selectors as cs from polars.exceptions import NoDataError, ParameterCollisionError +from polars.io.spreadsheet.functions import _identify_workbook from polars.testing import assert_frame_equal if TYPE_CHECKING: - from pathlib import Path - from polars.type_aliases import ExcelSpreadsheetEngine, SchemaDict, SelectorType pytestmark = pytest.mark.slow() +@pytest.fixture() +def path_xls(io_files_path: Path) -> Path: + # old excel 97-2004 format + return io_files_path / "example.xls" + + @pytest.fixture() def path_xlsx(io_files_path: Path) -> Path: + # modern excel format return io_files_path / "example.xlsx" +@pytest.fixture() +def path_xlsb(io_files_path: Path) -> Path: + # excel binary format + return io_files_path / "example.xlsb" + + +@pytest.fixture() +def path_ods(io_files_path: Path) -> Path: + # open document spreadsheet + return io_files_path / "example.ods" + + @pytest.fixture() def path_xlsx_empty(io_files_path: Path) -> Path: return io_files_path / "empty.xlsx" @@ -37,11 +55,6 @@ def path_xlsx_mixed(io_files_path: Path) -> Path: return io_files_path / "mixed.xlsx" -@pytest.fixture() -def path_xlsb(io_files_path: Path) -> Path: - return io_files_path / "example.xlsb" - - @pytest.fixture() def path_xlsb_empty(io_files_path: Path) -> Path: return io_files_path / "empty.xlsb" @@ -52,11 +65,6 @@ def path_xlsb_mixed(io_files_path: Path) -> Path: return io_files_path / "mixed.xlsb" -@pytest.fixture() -def path_ods(io_files_path: Path) -> Path: - return io_files_path / "example.ods" - - @pytest.fixture() def path_ods_empty(io_files_path: Path) -> Path: return io_files_path / "empty.ods" @@ -70,24 +78,16 @@ def path_ods_mixed(io_files_path: Path) -> Path: @pytest.mark.parametrize( ("read_spreadsheet", "source", "engine_params"), [ + # xls file + (pl.read_excel, "path_xls", {"engine": "calamine"}), + (pl.read_excel, "path_xls", {"engine": None}), # << autodetect # xlsx file (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), + (pl.read_excel, "path_xlsx", {"engine": None}), # << autodetect (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), - pytest.param( - *(pl.read_excel, "path_xlsx", {"engine": "calamine"}), - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), + (pl.read_excel, "path_xlsx", {"engine": "calamine"}), # xlsb file (binary) - pytest.param( - *(pl.read_excel, "path_xlsb", {"engine": "calamine"}), - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), + (pl.read_excel, "path_xlsb", {"engine": "calamine"}), (pl.read_excel, "path_xlsb", {"engine": "pyxlsb"}), # open document (pl.read_ods, "path_ods", {}), @@ -118,24 +118,14 @@ def test_read_spreadsheet( @pytest.mark.parametrize( ("read_spreadsheet", "source", "params"), [ + # xls file + (pl.read_excel, "path_xls", {"engine": "calamine"}), # xlsx file (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), - pytest.param( - *(pl.read_excel, "path_xlsx", {"engine": "calamine"}), - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), + (pl.read_excel, "path_xlsx", {"engine": "calamine"}), # xlsb file (binary) - pytest.param( - *(pl.read_excel, "path_xlsb", {"engine": "calamine"}), - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), + (pl.read_excel, "path_xlsb", {"engine": "calamine"}), (pl.read_excel, "path_xlsb", {"engine": "pyxlsb"}), # open document (pl.read_ods, "path_ods", {}), @@ -173,24 +163,14 @@ def test_read_excel_multi_sheets( @pytest.mark.parametrize( ("read_spreadsheet", "source", "params"), [ + # xls file + (pl.read_excel, "path_xls", {"engine": "calamine"}), # xlsx file (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), - pytest.param( - *(pl.read_excel, "path_xlsx", {"engine": "calamine"}), - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), + (pl.read_excel, "path_xlsx", {"engine": "calamine"}), # xlsb file (binary) - pytest.param( - *(pl.read_excel, "path_xlsb", {"engine": "calamine"}), - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), + (pl.read_excel, "path_xlsb", {"engine": "calamine"}), (pl.read_excel, "path_xlsb", {"engine": "pyxlsb"}), # open document (pl.read_ods, "path_ods", {}), @@ -231,13 +211,7 @@ def test_read_excel_all_sheets( ("engine", "schema_overrides"), [ ("xlsx2csv", {"datetime": pl.Datetime}), - pytest.param( - *("calamine", None), - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), + ("calamine", None), ("openpyxl", None), ], ) @@ -272,24 +246,14 @@ def test_read_excel_basic_datatypes( @pytest.mark.parametrize( ("read_spreadsheet", "source", "params"), [ + # xls file + (pl.read_excel, "path_xls", {"engine": "calamine"}), # xlsx file (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), - pytest.param( - *(pl.read_excel, "path_xlsx", {"engine": "calamine"}), - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), + (pl.read_excel, "path_xlsx", {"engine": "calamine"}), # xlsb file (binary) - pytest.param( - *(pl.read_excel, "path_xlsb", {"engine": "calamine"}), - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), + (pl.read_excel, "path_xlsb", {"engine": "calamine"}), (pl.read_excel, "path_xlsb", {"engine": "pyxlsb"}), # open document (pl.read_ods, "path_ods", {}), @@ -373,20 +337,7 @@ def test_read_mixed_dtype_columns( ) -@pytest.mark.parametrize( - "engine", - [ - "xlsx2csv", - "openpyxl", - pytest.param( - "calamine", - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), - ], -) +@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"]) def test_write_excel_bytes(engine: ExcelSpreadsheetEngine) -> None: df = pl.DataFrame({"A": [1.5, -2, 0, 3.0, -4.5, 5.0]}) @@ -403,6 +354,7 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N sheet_name="test4", schema_overrides={"cardinality": pl.UInt16}, ).drop_nulls() + assert df1.schema["cardinality"] == pl.UInt16 assert df1.schema["rows_by_key"] == pl.Float64 assert df1.schema["iter_groups"] == pl.Float64 @@ -410,8 +362,9 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N df2 = pl.read_excel( path_xlsx, sheet_name="test4", - read_csv_options={"dtypes": {"cardinality": pl.UInt16}}, + read_options={"dtypes": {"cardinality": pl.UInt16}}, ).drop_nulls() + assert df2.schema["cardinality"] == pl.UInt16 assert df2.schema["rows_by_key"] == pl.Float64 assert df2.schema["iter_groups"] == pl.Float64 @@ -420,19 +373,23 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N path_xlsx, sheet_name="test4", schema_overrides={"cardinality": pl.UInt16}, - read_csv_options={ + read_options={ "dtypes": { "rows_by_key": pl.Float32, "iter_groups": pl.Float32, }, }, ).drop_nulls() + assert df3.schema["cardinality"] == pl.UInt16 assert df3.schema["rows_by_key"] == pl.Float32 assert df3.schema["iter_groups"] == pl.Float32 for workbook_path in (path_xlsx, path_xlsb, path_ods): - df4 = pl.read_excel( + read_spreadsheet = ( + pl.read_ods if workbook_path.suffix == ".ods" else pl.read_excel + ) + df4 = read_spreadsheet( # type: ignore[operator] workbook_path, sheet_name="test5", schema_overrides={"dtm": pl.Datetime("ns"), "dt": pl.Date}, @@ -453,12 +410,12 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N ) with pytest.raises(ParameterCollisionError): - # cannot specify 'cardinality' in both schema_overrides and read_csv_options + # cannot specify 'cardinality' in both schema_overrides and read_options pl.read_excel( path_xlsx, sheet_name="test4", schema_overrides={"cardinality": pl.UInt16}, - read_csv_options={"dtypes": {"cardinality": pl.Int32}}, + read_options={"dtypes": {"cardinality": pl.Int32}}, ) # read multiple sheets in conjunction with 'schema_overrides' @@ -492,20 +449,7 @@ def test_unsupported_binary_workbook(path_xlsx: Path, path_xlsb: Path) -> None: pl.read_excel(path_xlsb, engine="openpyxl") -@pytest.mark.parametrize( - "engine", - [ - "xlsx2csv", - "openpyxl", - pytest.param( - "calamine", - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), - ], -) +@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"]) def test_read_excel_all_sheets_with_sheet_name(path_xlsx: Path, engine: str) -> None: with pytest.raises( ValueError, @@ -625,45 +569,39 @@ def test_excel_round_trip(write_params: dict[str, Any]) -> None: "val": [100.5, 55.0, -99.5], } ) - header_opts = ( - {} - if write_params.get("include_header", True) - else {"has_header": False, "new_columns": ["dtm", "str", "val"]} - ) - fmt_strptime = "%Y-%m-%d" - if write_params.get("dtype_formats", {}).get(pl.Date) == "dd-mm-yyyy": - fmt_strptime = "%d-%m-%Y" - # write to an xlsx with polars, using various parameters... - xls = BytesIO() - _wb = df.write_excel(workbook=xls, worksheet="data", **write_params) + engine: ExcelSpreadsheetEngine + for engine in ("calamine", "xlsx2csv"): # type: ignore[assignment] + table_params = ( + {} + if write_params.get("include_header", True) + else ( + {"has_header": False, "new_columns": ["dtm", "str", "val"]} + if engine == "xlsx2csv" + else {"header_row": None, "column_names": ["dtm", "str", "val"]} + ) + ) + fmt_strptime = "%Y-%m-%d" + if write_params.get("dtype_formats", {}).get(pl.Date) == "dd-mm-yyyy": + fmt_strptime = "%d-%m-%Y" - # ...and read it back again: - xldf = pl.read_excel( - xls, - sheet_name="data", - read_csv_options=header_opts, - )[:3] - xldf = xldf.select(xldf.columns[:3]).with_columns( - pl.col("dtm").str.strptime(pl.Date, fmt_strptime) - ) - assert_frame_equal(df, xldf) + # write to an xlsx with polars, using various parameters... + xls = BytesIO() + _wb = df.write_excel(workbook=xls, worksheet="data", **write_params) + # ...and read it back again: + xldf = pl.read_excel( + xls, + sheet_name="data", + engine=engine, + read_options=table_params, + )[:3].select(df.columns[:3]) + if engine == "xlsx2csv": + xldf = xldf.with_columns(pl.col("dtm").str.strptime(pl.Date, fmt_strptime)) + assert_frame_equal(df, xldf) -@pytest.mark.parametrize( - "engine", - [ - "xlsx2csv", - "openpyxl", - pytest.param( - "calamine", - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), - ], -) + +@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"]) def test_excel_compound_types( engine: ExcelSpreadsheetEngine, ) -> None: @@ -682,20 +620,7 @@ def test_excel_compound_types( ] -@pytest.mark.parametrize( - "engine", - [ - "xlsx2csv", - "openpyxl", - pytest.param( - "calamine", - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), - ], -) +@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"]) def test_excel_sparklines(engine: ExcelSpreadsheetEngine) -> None: from xlsxwriter import Workbook @@ -851,10 +776,15 @@ def test_excel_empty_sheet( request: pytest.FixtureRequest, ) -> None: empty_spreadsheet_path = request.getfixturevalue(source) + read_spreadsheet = ( + pl.read_ods # type: ignore[assignment] + if empty_spreadsheet_path.suffix == ".ods" + else pl.read_excel + ) with pytest.raises(NoDataError, match="empty Excel sheet"): - pl.read_excel(empty_spreadsheet_path) + read_spreadsheet(empty_spreadsheet_path) - df = pl.read_excel(empty_spreadsheet_path, raise_if_empty=False) + df = read_spreadsheet(empty_spreadsheet_path, raise_if_empty=False) assert_frame_equal(df, pl.DataFrame()) @@ -863,13 +793,7 @@ def test_excel_empty_sheet( [ ("xlsx2csv", ["a"]), ("openpyxl", ["a", "b"]), - pytest.param( - *("calamine", ["a", "b"]), - marks=pytest.mark.skipif( - sys.platform == "win32", - reason="fastexcel not yet available on Windows", - ), - ), + ("calamine", ["a", "b"]), ("xlsx2csv", cs.numeric()), ("openpyxl", cs.last()), ], @@ -887,11 +811,64 @@ def test_excel_hidden_columns( assert_frame_equal(df, read_df) -def test_invalid_engine_options() -> None: - # read_csv_options only applicable with 'xlsx2csv' engine - with pytest.raises(ValueError, match="cannot specify `read_csv_options`"): - pl.read_excel( - "", - engine="openpyxl", - read_csv_options={"sep": "\t"}, - ) +@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"]) +def test_excel_type_inference_with_nulls(engine: ExcelSpreadsheetEngine) -> None: + df = pl.DataFrame( + { + "a": [1, 2, None], + "b": [1.0, None, 3.5], + "c": ["x", None, "z"], + "d": [True, False, None], + "e": [date(2023, 1, 1), None, date(2023, 1, 4)], + "f": [ + datetime(2023, 1, 1), + datetime(2000, 10, 10, 10, 10), + None, + ], + } + ) + xls = BytesIO() + df.write_excel(xls) + + read_df = pl.read_excel( + xls, + engine=engine, + schema_overrides={ + "e": pl.Date, + "f": pl.Datetime("us"), + }, + ) + assert_frame_equal(df, read_df) + + +@pytest.mark.parametrize( + ("path", "file_type"), + [ + ("path_xls", "xls"), + ("path_xlsx", "xlsx"), + ("path_xlsb", "xlsb"), + ], +) +def test_identify_workbook( + path: str, file_type: str, request: pytest.FixtureRequest +) -> None: + # identify from file path + spreadsheet_path = request.getfixturevalue(path) + assert _identify_workbook(spreadsheet_path) == file_type + + # note that we can't distinguish between xlsx and xlsb + # from the magic bytes block alone (so we default to xlsx) + if file_type == "xlsb": + file_type = "xlsx" + + # identify from IO[bytes] + with Path.open(spreadsheet_path, "rb") as f: + assert _identify_workbook(f) == file_type + + # identify from bytes + with Path.open(spreadsheet_path, "rb") as f: + assert _identify_workbook(f.read()) == file_type + + # identify from BytesIO + with Path.open(spreadsheet_path, "rb") as f: + assert _identify_workbook(BytesIO(f.read())) == file_type diff --git a/py-polars/tests/unit/lazyframe/test_tree_format.py b/py-polars/tests/unit/lazyframe/test_tree_format.py new file mode 100644 index 000000000000..7ceb31fa5acc --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_tree_format.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import polars as pl + + +def test_logical_plan_tree_format() -> None: + lf = ( + pl.LazyFrame( + { + "foo": [1, 2, 3], + "bar": [6, 7, 8], + "ham": ["a", "b", "c"], + } + ) + .select(foo=pl.col("foo") + 1, bar=pl.col("bar") + 2) + .select( + threshold=pl.when(pl.col("foo") + pl.col("bar") > 2).then(10).otherwise(0) + ) + ) + + expected = """ + SELECT [.when([([(col("foo")) + (col("bar"))]) > (2)]).then(10).otherwise(0).alias("threshold")] FROM + SELECT [[(col("foo")) + (1)].alias("foo"), [(col("bar")) + (2)].alias("bar")] FROM + DF ["foo", "bar", "ham"]; PROJECT 2/3 COLUMNS; SELECTION: "None" +""" + assert lf.explain().strip() == expected.strip() + + expected = """ + 0 1 2 3 + ┌────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── + │ + │ ╭────────╮ + 0 │ │ SELECT │ + │ ╰───┬┬───╯ + │ ││ + │ │╰─────────────────────────────────────╮ + │ │ │ + │ ╭───────────────────────┴────────────────────────╮ │ + │ │ expression: │ ╭───┴────╮ + │ │ .when([([(col("foo")) + (col("bar"))]) > (2)]) │ │ FROM: │ + 1 │ │ .then(10) │ │ SELECT │ + │ │ .otherwise(0) │ ╰───┬┬───╯ + │ │ .alias("threshold") │ ││ + │ ╰────────────────────────────────────────────────╯ ││ + │ ││ + │ │╰────────────────────────┬───────────────────────────╮ + │ │ │ │ + │ ╭──────────┴───────────╮ ╭──────────┴───────────╮ ╭────────────┴─────────────╮ + │ │ expression: │ │ expression: │ │ FROM: │ + 2 │ │ [(col("foo")) + (1)] │ │ [(col("bar")) + (2)] │ │ DF ["foo", "bar", "ham"] │ + │ │ .alias("foo") │ │ .alias("bar") │ │ PROJECT 2/3 COLUMNS │ + │ ╰──────────────────────╯ ╰──────────────────────╯ ╰──────────────────────────╯ +""" + assert lf.explain(tree_format=True).strip() == expected.strip() diff --git a/py-polars/tests/unit/meta/__init__.py b/py-polars/tests/unit/meta/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/meta/test_build.py b/py-polars/tests/unit/meta/test_build.py new file mode 100644 index 000000000000..3ac048ffa193 --- /dev/null +++ b/py-polars/tests/unit/meta/test_build.py @@ -0,0 +1,26 @@ +import polars as pl + + +def test_build_info_version() -> None: + build_info = pl.build_info() + assert build_info["version"] == pl.__version__ + + +def test_build_info_keys() -> None: + build_info = pl.build_info() + expected_keys = [ + "build", + "info-time", + "dependencies", + "features", + "host", + "target", + "git", + "version", + ] + assert sorted(build_info.keys()) == sorted(expected_keys) + + +def test_build_info_features() -> None: + build_info = pl.build_info() + assert "BUILD_INFO" in build_info["features"] diff --git a/py-polars/tests/unit/meta/test_index_type.py b/py-polars/tests/unit/meta/test_index_type.py new file mode 100644 index 000000000000..07bc112b3dcd --- /dev/null +++ b/py-polars/tests/unit/meta/test_index_type.py @@ -0,0 +1,5 @@ +import polars as pl + + +def test_get_index_type() -> None: + assert pl.get_index_type() == pl.UInt32() diff --git a/py-polars/tests/unit/meta/test_thread_pool.py b/py-polars/tests/unit/meta/test_thread_pool.py new file mode 100644 index 000000000000..159ab89cc946 --- /dev/null +++ b/py-polars/tests/unit/meta/test_thread_pool.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +def test_thread_pool_size() -> None: + result = pl.thread_pool_size() + assert isinstance(result, int) + + +def test_threadpool_size_deprecated() -> None: + with pytest.deprecated_call(): + result = pl.threadpool_size() + assert isinstance(result, int) diff --git a/py-polars/tests/unit/utils/test_show_versions.py b/py-polars/tests/unit/meta/test_versions.py similarity index 100% rename from py-polars/tests/unit/utils/test_show_versions.py rename to py-polars/tests/unit/meta/test_versions.py diff --git a/py-polars/tests/unit/namespaces/array/test_array.py b/py-polars/tests/unit/namespaces/array/test_array.py index 7e04b5e04546..4486b90eeddf 100644 --- a/py-polars/tests/unit/namespaces/array/test_array.py +++ b/py-polars/tests/unit/namespaces/array/test_array.py @@ -3,7 +3,6 @@ import datetime from typing import Any -import numpy as np import pytest import polars as pl @@ -70,11 +69,6 @@ def test_arr_unique() -> None: assert_frame_equal(out, expected) -def test_array_to_numpy() -> None: - s = pl.Series([[1, 2], [3, 4], [5, 6]], dtype=pl.Array(pl.Int64, 2)) - assert (s.to_numpy() == np.array([[1, 2], [3, 4], [5, 6]])).all() - - def test_array_any_all() -> None: s = pl.Series( [[True, True], [False, True], [False, False], [None, None], None], @@ -329,3 +323,46 @@ def test_array_count_matches( df = pl.DataFrame({"arr": arr}, schema={"arr": pl.Array(dtype, 2)}) out = df.select(count_matches=pl.col("arr").arr.count_matches(data)) assert out.to_dict(as_series=False) == {"count_matches": expected} + + +def test_array_to_struct() -> None: + df = pl.DataFrame( + {"a": [[1, 2, 3], [4, 5, None]]}, schema={"a": pl.Array(pl.Int8, 3)} + ) + assert df.select([pl.col("a").arr.to_struct()]).to_series().to_list() == [ + {"field_0": 1, "field_1": 2, "field_2": 3}, + {"field_0": 4, "field_1": 5, "field_2": None}, + ] + + df = pl.DataFrame( + {"a": [[1, 2, None], [1, 2, 3]]}, schema={"a": pl.Array(pl.Int8, 3)} + ) + assert df.select( + [pl.col("a").arr.to_struct(fields=lambda idx: f"col_name_{idx}")] + ).to_series().to_list() == [ + {"col_name_0": 1, "col_name_1": 2, "col_name_2": None}, + {"col_name_0": 1, "col_name_1": 2, "col_name_2": 3}, + ] + + assert df.lazy().select(pl.col("a").arr.to_struct()).unnest( + "a" + ).sum().collect().columns == ["field_0", "field_1", "field_2"] + + +def test_array_shift() -> None: + df = pl.DataFrame( + {"a": [[1, 2, 3], None, [4, 5, 6], [7, 8, 9]], "n": [None, 1, 1, -2]}, + schema={"a": pl.Array(pl.Int64, 3), "n": pl.Int64}, + ) + + out = df.select( + lit=pl.col("a").arr.shift(1), expr=pl.col("a").arr.shift(pl.col("n")) + ) + expected = pl.DataFrame( + { + "lit": [[None, 1, 2], None, [None, 4, 5], [None, 7, 8]], + "expr": [None, None, [None, 4, 5], [9, None, None]], + }, + schema={"lit": pl.Array(pl.Int64, 3), "expr": pl.Array(pl.Int64, 3)}, + ) + assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/namespaces/files/test_tree_fmt.txt b/py-polars/tests/unit/namespaces/files/test_tree_fmt.txt index 1c7fa9659c50..c3a4f4b23c53 100644 --- a/py-polars/tests/unit/namespaces/files/test_tree_fmt.txt +++ b/py-polars/tests/unit/namespaces/files/test_tree_fmt.txt @@ -1,66 +1,95 @@ (pl.col("foo") * pl.col("bar")).sum().over("ham", "ham2") / 2 - 0 1 2 3 4 - ┌───────────────────────────────────────────────────────────────────────────────── + 0 1 2 3 4 + ┌───────────────────────────────────────────────────────────────────────── │ - │ ╭───────────╮ - 0 │ │ binary: / │ - │ ╰───────────╯ - │ │ ╰─────────────╮ - │ │ │ - │ │ │ - │ ╭────────╮ ╭────────╮ - 1 │ │ lit(2) │ │ window │ - │ ╰────────╯ ╰────────╯ - │ │ ╰──────────────╮───────────────╮ - │ │ │ │ - │ │ │ │ - │ ╭───────────╮ ╭──────────╮ ╭─────╮ - 2 │ │ col(ham2) │ │ col(ham) │ │ sum │ - │ ╰───────────╯ ╰──────────╯ ╰─────╯ - │ │ - │ │ - │ │ - │ ╭───────────╮ - 3 │ │ binary: * │ - │ ╰───────────╯ - │ │ ╰──────────────╮ - │ │ │ - │ │ │ - │ ╭──────────╮ ╭──────────╮ - 4 │ │ col(bar) │ │ col(foo) │ - │ ╰──────────╯ ╰──────────╯ + │ ╭───────────╮ + 0 │ │ binary: / │ + │ ╰─────┬┬────╯ + │ ││ + │ │╰─────────────╮ + │ │ │ + │ ╭───┴────╮ ╭───┴────╮ + 1 │ │ lit(2) │ │ window │ + │ ╰────────╯ ╰───┬┬───╯ + │ ││ + │ │╰────────────┬──────────────╮ + │ │ │ │ + │ ╭─────┴─────╮ ╭────┴─────╮ ╭──┴──╮ + 2 │ │ col(ham2) │ │ col(ham) │ │ sum │ + │ ╰───────────╯ ╰──────────╯ ╰──┬──╯ + │ │ + │ │ + │ │ + │ ╭─────┴─────╮ + 3 │ │ binary: * │ + │ ╰─────┬┬────╯ + │ ││ + │ │╰────────────╮ + │ │ │ + │ ╭────┴─────╮ ╭────┴─────╮ + 4 │ │ col(bar) │ │ col(foo) │ + │ ╰──────────╯ ╰──────────╯ --- (pl.col("foo") * pl.col("bar")).sum().over(pl.col("ham")) / 2 - 0 1 2 3 - ┌──────────────────────────────────────────────────────────────── + 0 1 2 3 + ┌────────────────────────────────────────────────────────── │ - │ ╭───────────╮ - 0 │ │ binary: / │ - │ ╰───────────╯ - │ │ ╰─────────────╮ - │ │ │ - │ │ │ - │ ╭────────╮ ╭────────╮ - 1 │ │ lit(2) │ │ window │ - │ ╰────────╯ ╰────────╯ - │ │ ╰─────────────╮ - │ │ │ - │ │ │ - │ ╭──────────╮ ╭─────╮ - 2 │ │ col(ham) │ │ sum │ - │ ╰──────────╯ ╰─────╯ - │ │ - │ │ - │ │ - │ ╭───────────╮ - 3 │ │ binary: * │ - │ ╰───────────╯ - │ │ ╰──────────────╮ - │ │ │ - │ │ │ - │ ╭──────────╮ ╭──────────╮ - 4 │ │ col(bar) │ │ col(foo) │ - │ ╰──────────╯ ╰──────────╯ + │ ╭───────────╮ + 0 │ │ binary: / │ + │ ╰─────┬┬────╯ + │ ││ + │ │╰────────────╮ + │ │ │ + │ ╭───┴────╮ ╭───┴────╮ + 1 │ │ lit(2) │ │ window │ + │ ╰────────╯ ╰───┬┬───╯ + │ ││ + │ │╰─────────────╮ + │ │ │ + │ ╭────┴─────╮ ╭──┴──╮ + 2 │ │ col(ham) │ │ sum │ + │ ╰──────────╯ ╰──┬──╯ + │ │ + │ │ + │ │ + │ ╭─────┴─────╮ + 3 │ │ binary: * │ + │ ╰─────┬┬────╯ + │ ││ + │ │╰────────────╮ + │ │ │ + │ ╭────┴─────╮ ╭────┴─────╮ + 4 │ │ col(bar) │ │ col(foo) │ + │ ╰──────────╯ ╰──────────╯ + +--- +(pl.col("a") + pl.col("b"))**2 + pl.int_range(3) + + 0 1 2 3 4 + ┌─────────────────────────────────────────────────────────────────────────────────── + │ + │ ╭───────────╮ + 0 │ │ binary: + │ + │ ╰─────┬┬────╯ + │ ││ + │ │╰────────────────────────────────╮ + │ │ │ + │ ╭──────────┴──────────╮ ╭───────┴───────╮ + 1 │ │ function: int_range │ │ function: pow │ + │ ╰──────────┬┬─────────╯ ╰───────┬┬──────╯ + │ ││ ││ + │ │╰────────────────╮ │╰───────────────╮ + │ │ │ │ │ + │ ╭───┴────╮ ╭───┴────╮ ╭───┴────╮ ╭─────┴─────╮ + 2 │ │ lit(3) │ │ lit(0) │ │ lit(2) │ │ binary: + │ + │ ╰────────╯ ╰────────╯ ╰────────╯ ╰─────┬┬────╯ + │ ││ + │ │╰───────────╮ + │ │ │ + │ ╭───┴────╮ ╭───┴────╮ + 3 │ │ col(b) │ │ col(a) │ + │ ╰────────╯ ╰────────╯ + diff --git a/py-polars/tests/unit/namespaces/list/__init__.py b/py-polars/tests/unit/namespaces/list/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/list/test_list.py similarity index 87% rename from py-polars/tests/unit/namespaces/test_list.py rename to py-polars/tests/unit/namespaces/list/test_list.py index edd7609a2ade..97bf07b634d2 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/list/test_list.py @@ -290,6 +290,16 @@ def test_list_eval_dtype_inference() -> None: ] +def test_list_eval_categorical() -> None: + df = pl.DataFrame({"test": [["a", None]]}, schema={"test": pl.List(pl.Categorical)}) + df = df.select( + pl.col("test").list.eval(pl.element().filter(pl.element().is_not_null())) + ) + assert_series_equal( + df.get_column("test"), pl.Series("test", [["a"]], dtype=pl.List(pl.Categorical)) + ) + + def test_list_ternary_concat() -> None: df = pl.DataFrame( { @@ -607,135 +617,6 @@ def test_list_count_matches_boolean_nulls_9141() -> None: assert a.select(pl.col("a").list.count_matches(True))["a"].to_list() == [1] -def test_list_set_oob() -> None: - df = pl.DataFrame({"a": [42, 23]}) - assert df.select(pl.col("a").list.set_intersection([])).to_dict( - as_series=False - ) == {"a": [[], []]} - - -def test_list_set_operations_float() -> None: - df = pl.DataFrame( - {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]}, - schema={"a": pl.List(pl.Float32), "b": pl.List(pl.Float32)}, - ) - - assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ - [1.0, 2.0, 3.0, 4.0], - [1.0, 2.0, 12.0], - [4.0], - ] - assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ - [1.0, 2.0], - [1.0], - [4.0], - ] - assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ - [3.0], - [], - [], - ] - assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ - [4.0], - [2.0, 12.0], - [], - ] - - -def test_list_set_operations() -> None: - df = pl.DataFrame( - {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]} - ) - - assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ - [1, 2, 3, 4], - [1, 2, 12], - [4], - ] - assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ - [1, 2], - [1], - [4], - ] - assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ - [3], - [], - [], - ] - assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ - [4], - [2, 12], - [], - ] - - # check logical types - dtype = pl.List(pl.Date) - assert ( - df.select(pl.col("b").cast(dtype).list.set_difference(pl.col("a").cast(dtype)))[ - "b" - ].dtype - == dtype - ) - - df = pl.DataFrame( - { - "a": [["a", "b", "c"], ["b", "e", "z"]], - "b": [["b", "s", "a"], ["a", "e", "f"]], - } - ) - - assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ - ["a", "b", "c", "s"], - ["b", "e", "z", "a", "f"], - ] - - df = pl.DataFrame( - { - "a": [[2, 3, 3], [3, 1], [1, 2, 3]], - "b": [[2, 3, 4], [3, 3, 1], [3, 3]], - } - ) - r1 = df.with_columns(pl.col("a").list.set_intersection("b"))["a"].to_list() - r2 = df.with_columns(pl.col("b").list.set_intersection("a"))["b"].to_list() - exp = [[2, 3], [3, 1], [3]] - assert r1 == exp - assert r2 == exp - - -def test_list_set_operations_broadcast() -> None: - df = pl.DataFrame( - { - "a": [[2, 3, 3], [3, 1], [1, 2, 3]], - } - ) - - assert df.with_columns( - pl.col("a").list.set_intersection(pl.lit(pl.Series([[1, 2]]))) - ).to_dict(as_series=False) == {"a": [[2], [1], [1, 2]]} - assert df.with_columns( - pl.col("a").list.set_union(pl.lit(pl.Series([[1, 2]]))) - ).to_dict(as_series=False) == {"a": [[2, 3, 1], [3, 1, 2], [1, 2, 3]]} - assert df.with_columns( - pl.col("a").list.set_difference(pl.lit(pl.Series([[1, 2]]))) - ).to_dict(as_series=False) == {"a": [[3], [3], [3]]} - assert df.with_columns( - pl.lit(pl.Series("a", [[1, 2]])).list.set_difference("a") - ).to_dict(as_series=False) == {"a": [[1], [2], []]} - - -def test_list_set_operation_different_length_chunk_12734() -> None: - df = pl.DataFrame( - { - "a": [[2, 3, 3], [4, 1], [1, 2, 3]], - } - ) - - df = pl.concat([df.slice(0, 1), df.slice(1, 1), df.slice(2, 1)], rechunk=False) - assert df.with_columns( - pl.col("a").list.set_difference(pl.lit(pl.Series([[1, 2]]))) - ).to_dict(as_series=False) == {"a": [[3], [4], [3]]} - - def test_list_gather_oob_10079() -> None: df = pl.DataFrame( { @@ -879,3 +760,45 @@ def test_list_eval_gater_every_13410() -> None: out = df.with_columns(result=pl.col("a").list.eval(pl.element().gather_every(2))) expected = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]], "result": [[1, 3], [4, 6]]}) assert_frame_equal(out, expected) + + +def test_list_gather_every() -> None: + df = pl.DataFrame( + { + "lst": [[1, 2, 3], [], [4, 5], None, [6, 7, 8], [9, 10, 11, 12]], + "n": [2, 2, 1, 3, None, 2], + "offset": [None, 1, 0, 1, 2, 2], + } + ) + + out = df.select( + n_expr=pl.col("lst").list.gather_every(pl.col("n"), 0), + offset_expr=pl.col("lst").list.gather_every(2, pl.col("offset")), + all_expr=pl.col("lst").list.gather_every(pl.col("n"), pl.col("offset")), + all_lit=pl.col("lst").list.gather_every(2, 0), + ) + + expected = pl.DataFrame( + { + "n_expr": [[1, 3], [], [4, 5], None, None, [9, 11]], + "offset_expr": [None, [], [4], None, [8], [11]], + "all_expr": [None, [], [4, 5], None, None, [11]], + "all_lit": [[1, 3], [], [4], None, [6, 8], [9, 11]], + } + ) + + assert_frame_equal(out, expected) + + +def test_list_n_unique() -> None: + df = pl.DataFrame( + { + "a": [[1, 1, 2], [3, 3], [None], None, []], + } + ) + + out = df.select(n_unique=pl.col("a").list.n_unique()) + expected = pl.DataFrame( + {"n_unique": [2, 1, 1, None, 0]}, schema={"n_unique": pl.UInt32} + ) + assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/namespaces/list/test_set_operations.py b/py-polars/tests/unit/namespaces/list/test_set_operations.py new file mode 100644 index 000000000000..8082b33391cf --- /dev/null +++ b/py-polars/tests/unit/namespaces/list/test_set_operations.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_list_set_oob() -> None: + df = pl.DataFrame({"a": [[42], [23]]}) + result = df.select(pl.col("a").list.set_intersection([])) + assert result.to_dict(as_series=False) == {"a": [[], []]} + + +def test_list_set_operations_float() -> None: + df = pl.DataFrame( + {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]}, + schema={"a": pl.List(pl.Float32), "b": pl.List(pl.Float32)}, + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 12.0], + [4.0], + ] + assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ + [1.0, 2.0], + [1.0], + [4.0], + ] + assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ + [3.0], + [], + [], + ] + assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ + [4.0], + [2.0, 12.0], + [], + ] + + +def test_list_set_operations() -> None: + df = pl.DataFrame( + {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]} + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + [1, 2, 3, 4], + [1, 2, 12], + [4], + ] + assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ + [1, 2], + [1], + [4], + ] + assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ + [3], + [], + [], + ] + assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ + [4], + [2, 12], + [], + ] + + # check logical types + dtype = pl.List(pl.Date) + assert ( + df.select(pl.col("b").cast(dtype).list.set_difference(pl.col("a").cast(dtype)))[ + "b" + ].dtype + == dtype + ) + + df = pl.DataFrame( + { + "a": [["a", "b", "c"], ["b", "e", "z"]], + "b": [["b", "s", "a"], ["a", "e", "f"]], + } + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + ["a", "b", "c", "s"], + ["b", "e", "z", "a", "f"], + ] + + df = pl.DataFrame( + { + "a": [[2, 3, 3], [3, 1], [1, 2, 3]], + "b": [[2, 3, 4], [3, 3, 1], [3, 3]], + } + ) + r1 = df.with_columns(pl.col("a").list.set_intersection("b"))["a"].to_list() + r2 = df.with_columns(pl.col("b").list.set_intersection("a"))["b"].to_list() + exp = [[2, 3], [3, 1], [3]] + assert r1 == exp + assert r2 == exp + + +def test_list_set_operations_broadcast() -> None: + df = pl.DataFrame( + { + "a": [[2, 3, 3], [3, 1], [1, 2, 3]], + } + ) + + assert df.with_columns( + pl.col("a").list.set_intersection(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[2], [1], [1, 2]]} + assert df.with_columns( + pl.col("a").list.set_union(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[2, 3, 1], [3, 1, 2], [1, 2, 3]]} + assert df.with_columns( + pl.col("a").list.set_difference(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[3], [3], [3]]} + assert df.with_columns( + pl.lit(pl.Series("a", [[1, 2]])).list.set_difference("a") + ).to_dict(as_series=False) == {"a": [[1], [2], []]} + + +def test_list_set_operation_different_length_chunk_12734() -> None: + df = pl.DataFrame( + { + "a": [[2, 3, 3], [4, 1], [1, 2, 3]], + } + ) + + df = pl.concat([df.slice(0, 1), df.slice(1, 1), df.slice(2, 1)], rechunk=False) + assert df.with_columns( + pl.col("a").list.set_difference(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[3], [4], [3]]} + + +def test_list_set_operations_binary() -> None: + df = pl.DataFrame( + { + "a": [[b"1", b"2", b"3"], [b"1", b"1", b"1"], [b"4"]], + "b": [[b"4", b"2", b"1"], [b"2", b"1", b"12"], [b"4"]], + }, + schema={"a": pl.List(pl.Binary), "b": pl.List(pl.Binary)}, + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + [b"1", b"2", b"3", b"4"], + [b"1", b"2", b"12"], + [b"4"], + ] + assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ + [b"1", b"2"], + [b"1"], + [b"4"], + ] + assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ + [b"3"], + [], + [], + ] + assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ + [b"4"], + [b"2", b"12"], + [], + ] + + +def test_set_operations_14290() -> None: + df = pl.DataFrame( + { + "a": [[1, 2], [2, 3]], + "b": [None, [1, 2]], + } + ) + + out = df.with_columns(pl.col("a").shift(1).alias("shifted_a")).select( + b_dif_a=pl.col("b").list.set_difference("a"), + shifted_a_dif_a=pl.col("shifted_a").list.set_difference("a"), + ) + expected = pl.DataFrame({"b_dif_a": [None, [1]], "shifted_a_dif_a": [None, [1]]}) + assert_frame_equal(out, expected) + + +def test_broadcast_sliced() -> None: + df = pl.DataFrame({"a": [[1, 2], [3, 4]]}) + out = df.select( + pl.col("a").list.set_difference(pl.Series([[1], [2, 3, 4]]).slice(0, 1)) + ) + expected = pl.DataFrame({"a": [[2], [3, 4]]}) + + assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/namespaces/test_datetime.py b/py-polars/tests/unit/namespaces/test_datetime.py index ff56ca59a536..fd733b228b69 100644 --- a/py-polars/tests/unit/namespaces/test_datetime.py +++ b/py-polars/tests/unit/namespaces/test_datetime.py @@ -20,7 +20,7 @@ from backports.zoneinfo._zoneinfo import ZoneInfo if TYPE_CHECKING: - from polars.type_aliases import TimeUnit + from polars.type_aliases import TemporalLiteral, TimeUnit @pytest.fixture() @@ -906,6 +906,11 @@ def test_year_empty_df() -> None: assert df.select(pl.col("date").dt.year()).dtypes == [pl.Int32] +def test_epoch_invalid() -> None: + with pytest.raises(InvalidOperationError, match="not supported for dtype"): + pl.Series([timedelta(1)]).dt.epoch() + + @pytest.mark.parametrize( "time_unit", ["ms", "us", "ns"], @@ -926,21 +931,36 @@ def test_weekday(time_unit: TimeUnit) -> None: ([date(2022, 1, 1)], date(2022, 1, 1)), ([date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3)], date(2022, 1, 2)), ([date(2022, 1, 1), date(2022, 1, 2), date(2024, 5, 15)], date(2022, 1, 2)), + ([datetime(2022, 1, 1)], datetime(2022, 1, 1)), + ( + [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2022, 1, 3)], + datetime(2022, 1, 2), + ), ( [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2024, 5, 15)], datetime(2022, 1, 2), ), + ([timedelta(days=1)], timedelta(days=1)), + ([timedelta(days=1), timedelta(days=2), timedelta(days=3)], timedelta(days=2)), + ([timedelta(days=1), timedelta(days=2), timedelta(days=15)], timedelta(days=2)), ], ids=[ "empty", "Nones", - "single", - "spread_even", - "spread_skewed", - "spread_skewed_dt", + "single_date", + "spread_even_date", + "spread_skewed_date", + "single_datetime", + "spread_even_datetime", + "spread_skewed_datetime", + "single_dur", + "spread_even_dur", + "spread_skewed_dur", ], ) -def test_median(values: list[date | None], expected_median: date | None) -> None: +def test_median( + values: list[TemporalLiteral | None], expected_median: TemporalLiteral | None +) -> None: s = pl.Series(values) assert s.dt.median() == expected_median @@ -956,22 +976,35 @@ def test_median(values: list[date | None], expected_median: date | None) -> None ([date(2022, 1, 1)], date(2022, 1, 1)), ([date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3)], date(2022, 1, 2)), ([date(2022, 1, 1), date(2022, 1, 2), date(2024, 5, 15)], date(2022, 10, 16)), + ([datetime(2022, 1, 1)], datetime(2022, 1, 1)), + ( + [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2022, 1, 3)], + datetime(2022, 1, 2), + ), ( [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2024, 5, 15)], datetime(2022, 10, 16, 16, 0, 0), ), + ([timedelta(days=1)], timedelta(days=1)), + ([timedelta(days=1), timedelta(days=2), timedelta(days=3)], timedelta(days=2)), + ([timedelta(days=1), timedelta(days=2), timedelta(days=15)], timedelta(days=6)), ], ids=[ "empty", "Nones", - "single", - "spread_even", - "spread_skewed", - "spread_skewed_dt", + "single_date", + "spread_even_date", + "spread_skewed_date", + "single_datetime", + "spread_even_datetime", + "spread_skewed_datetime", + "single_duration", + "spread_even_duration", + "spread_skewed_duration", ], ) def test_mean( - values: list[date | datetime | None], expected_mean: date | datetime | None + values: list[TemporalLiteral | None], expected_mean: TemporalLiteral | None ) -> None: s = pl.Series(values) assert s.dt.mean() == expected_mean @@ -991,12 +1024,44 @@ def test_mean( ids=["spread_skewed_dt"], ) def test_datetime_mean_with_tu(values: list[datetime], expected_mean: datetime) -> None: - assert pl.Series(values, dtype=pl.Datetime("ms")).mean() == expected_mean - assert pl.Series(values, dtype=pl.Datetime("ms")).dt.mean() == expected_mean - assert pl.Series(values, dtype=pl.Datetime("us")).mean() == expected_mean - assert pl.Series(values, dtype=pl.Datetime("us")).dt.mean() == expected_mean - assert pl.Series(values, dtype=pl.Datetime("ns")).mean() == expected_mean - assert pl.Series(values, dtype=pl.Datetime("ns")).dt.mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("ms")).mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("ms")).dt.mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("us")).mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("us")).dt.mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("ns")).mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("ns")).dt.mean() == expected_mean + + +@pytest.mark.parametrize( + ("values", "expected_mean"), + [([timedelta(days=1), timedelta(days=2), timedelta(days=15)], timedelta(days=6))], + ids=["spread_skewed_dur"], +) +def test_duration_mean_with_tu( + values: list[timedelta], expected_mean: timedelta +) -> None: + assert pl.Series(values, dtype=pl.Duration("ms")).mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("ms")).dt.mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("us")).mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("us")).dt.mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("ns")).mean() == expected_mean + assert pl.Series(values, dtype=pl.Duration("ns")).dt.mean() == expected_mean + + +@pytest.mark.parametrize( + ("values", "expected_median"), + [([timedelta(days=1), timedelta(days=2), timedelta(days=15)], timedelta(days=2))], + ids=["spread_skewed_dur"], +) +def test_duration_median_with_tu( + values: list[timedelta], expected_median: timedelta +) -> None: + assert pl.Series(values, dtype=pl.Duration("ms")).median() == expected_median + assert pl.Series(values, dtype=pl.Duration("ms")).dt.median() == expected_median + assert pl.Series(values, dtype=pl.Duration("us")).median() == expected_median + assert pl.Series(values, dtype=pl.Duration("us")).dt.median() == expected_median + assert pl.Series(values, dtype=pl.Duration("ns")).median() == expected_median + assert pl.Series(values, dtype=pl.Duration("ns")).dt.median() == expected_median def test_agg_expr() -> None: diff --git a/py-polars/tests/unit/namespaces/test_struct.py b/py-polars/tests/unit/namespaces/test_struct.py index 37c284bf9451..ee4806c00188 100644 --- a/py-polars/tests/unit/namespaces/test_struct.py +++ b/py-polars/tests/unit/namespaces/test_struct.py @@ -1,5 +1,8 @@ from __future__ import annotations +import datetime +from collections import OrderedDict + import polars as pl from polars.testing import assert_frame_equal @@ -42,3 +45,41 @@ def test_struct_json_encode() -> None: "a": [{"a": [1, 2], "b": [45]}, {"a": [9, 1, 3], "b": None}], "encoded": ['{"a":[1,2],"b":[45]}', '{"a":[9,1,3],"b":null}'], } + + +def test_struct_json_encode_logical_type() -> None: + df = pl.DataFrame( + { + "a": [ + { + "a": [datetime.date(1997, 1, 1)], + "b": [datetime.datetime(2000, 1, 29, 10, 30)], + "c": [datetime.timedelta(1, 25)], + } + ] + } + ).select(pl.col("a").struct.json_encode().alias("encoded")) + assert df.to_dict(as_series=False) == { + "encoded": ['{"a":["1997-01-01"],"b":["2000-01-29 10:30:00"],"c":["PT86425S"]}'] + } + + +def test_map_fields() -> None: + df = pl.DataFrame({"x": {"a": 1, "b": 2}}) + assert df.schema == OrderedDict([("x", pl.Struct({"a": pl.Int64, "b": pl.Int64}))]) + df = df.select(pl.col("x").name.map_fields(lambda x: x.upper())) + assert df.schema == OrderedDict([("x", pl.Struct({"A": pl.Int64, "B": pl.Int64}))]) + + +def test_prefix_suffix_fields() -> None: + df = pl.DataFrame({"x": {"a": 1, "b": 2}}) + + prefix_df = df.select(pl.col("x").name.prefix_fields("p_")) + assert prefix_df.schema == OrderedDict( + [("x", pl.Struct({"p_a": pl.Int64, "p_b": pl.Int64}))] + ) + + suffix_df = df.select(pl.col("x").name.suffix_fields("_f")) + assert suffix_df.schema == OrderedDict( + [("x", pl.Struct({"a_f": pl.Int64, "b_f": pl.Int64}))] + ) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 317cf68ffba1..5d31bd1e2559 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -204,12 +204,12 @@ def test_boolean_addition() -> None: {"a": [True, False, False], "b": [True, False, True]} ).sum_horizontal() - assert s.dtype == pl.utils.get_index_type() + assert s.dtype == pl.get_index_type() assert s.to_list() == [2, 0, 1] df = pl.DataFrame( {"a": [True], "b": [False]}, ).select(pl.sum_horizontal("a", "b")) - assert df.dtypes == [pl.utils.get_index_type()] + assert df.dtypes == [pl.get_index_type()] def test_bitwise_6311() -> None: @@ -270,3 +270,26 @@ def test_operator_arithmetic_with_nulls(op: Any) -> None: assert_frame_equal(df_expected, op(df, None)) assert_series_equal(s_expected, op(s, None)) + + +@pytest.mark.parametrize( + "op", + [ + operator.add, + operator.mod, + operator.mul, + operator.sub, + ], +) +def test_null_column_arithmetic(op: Any) -> None: + df = pl.DataFrame({"a": [None, None], "b": [None, None]}) + expected_df = pl.DataFrame({"a": [None, None]}) + + output_df = df.select(op(pl.col("a"), pl.col("b"))) + assert_frame_equal(expected_df, output_df) + # test broadcast right + output_df = df.select(op(pl.col("a"), pl.Series([None]))) + assert_frame_equal(expected_df, output_df) + # test broadcast left + output_df = df.select(op(pl.Series("a", [None]), pl.col("a"))) + assert_frame_equal(expected_df, output_df) diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index 3b3b1ffe8966..a120c4076691 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -133,7 +133,7 @@ '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)', ), # --------------------------------------------- - # string expr: case/cast ops + # string exprs # --------------------------------------------- ("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.String).str.to_titlecase()'), ( @@ -141,6 +141,21 @@ 'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()', '(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()', ), + ( + "b", + "lambda x: x.strip().startswith('#')", + """pl.col("b").str.strip_chars().str.starts_with('#')""", + ), + ( + "b", + """lambda x: x.rstrip().endswith(('!','#','?','"'))""", + """pl.col("b").str.strip_chars_end().str.contains(r'(!|\\#|\\?|")$')""", + ), + ( + "b", + """lambda x: x.lstrip().startswith(('!','#','?',"'"))""", + """pl.col("b").str.strip_chars_start().str.contains(r"^(!|\\#|\\?|')")""", + ), # --------------------------------------------- # json expr: load/extract # --------------------------------------------- @@ -168,17 +183,30 @@ 'pl.col("d").str.to_datetime(format="%Y-%m-%d")', ), # --------------------------------------------- + # temporal attributes/methods + # --------------------------------------------- + ( + "f", + "lambda x: x.isoweekday()", + 'pl.col("f").dt.weekday()', + ), + ( + "f", + "lambda x: x.hour + x.minute + x.second", + '(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()', + ), + # --------------------------------------------- # Bitwise shifts # --------------------------------------------- ( "a", "lambda x: (3 << (32-x)) & 3", - '(3*2**(32 - pl.col("a"))).cast(pl.Int64) & 3', + '(3 * 2**(32 - pl.col("a"))).cast(pl.Int64) & 3', ), ( "a", "lambda x: (x << 32) & 3", - '(pl.col("a")*2**32).cast(pl.Int64) & 3', + '(pl.col("a") * 2**32).cast(pl.Int64) & 3', ), ( "a", @@ -227,6 +255,7 @@ def test_parse_invalid_function(func: str) -> None: ("col", "func", "expr_repr"), TEST_CASES, ) +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: with pytest.warns( PolarsInefficientMapWarning, @@ -243,6 +272,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: "c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'], "d": ["2020-01-01", "2020-01-02", "2020-01-03"], "e": [1.5, 2.4, 3.1], + "f": [ + datetime(1999, 12, 31), + datetime(2024, 5, 6), + datetime(2077, 10, 20), + ], } ) result_frame = df.select( @@ -251,11 +285,16 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: ) expected_frame = df.select( x=pl.col(col), - y=pl.col(col).apply(eval(func)), + y=pl.col(col).map_elements(eval(func)), + ) + assert_frame_equal( + result_frame, + expected_frame, + check_dtype=(".dt." not in suggested_expression), ) - assert_frame_equal(result_frame, expected_frame) +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") def test_parse_apply_raw_functions() -> None: lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]}) @@ -334,7 +373,7 @@ def x10(self, x: pl.Expr) -> pl.Expr: ): pl_series = pl.Series("srs", [0, 1, 2, 3, 4]) assert_series_equal( - pl_series.apply(lambda x: numpy.cos(3) + x - abs(-1)), + pl_series.map_elements(lambda x: numpy.cos(3) + x - abs(-1)), numpy.cos(3) + pl_series - 1, ) @@ -379,7 +418,7 @@ def test_parse_apply_series( suggested_expression = parser.to_expression(s.name) assert suggested_expression == expr_repr - expected_series = s.apply(func) + expected_series = s.map_elements(func) result_series = eval(suggested_expression) assert_series_equal(expected_series, result_series) @@ -409,6 +448,15 @@ def test_expr_exact_warning_message() -> None: assert len(warnings) == 1 +def test_omit_implicit_bool() -> None: + parser = BytecodeParser( + function=lambda x: x and x and x.date(), + map_target="expr", + ) + suggested_expression = parser.to_expression("d") + assert suggested_expression == 'pl.col("d").dt.date()' + + def test_partial_functions_13523() -> None: def plus(value, amount: int): # type: ignore[no-untyped-def] return value + amount diff --git a/py-polars/tests/unit/operations/map/test_map_batches.py b/py-polars/tests/unit/operations/map/test_map_batches.py index 2cde056cc652..457df189fa00 100644 --- a/py-polars/tests/unit/operations/map/test_map_batches.py +++ b/py-polars/tests/unit/operations/map/test_map_batches.py @@ -77,3 +77,21 @@ def test_map_deprecated() -> None: pl.col("a").map(lambda x: x) with pytest.deprecated_call(): pl.LazyFrame({"a": [1, 2]}).map(lambda x: x) + + +def test_ufunc_args() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]}) + result = df.select( + z=np.add( # type: ignore[call-overload] + pl.col("a"), pl.col("b") + ) + ) + expected = pl.DataFrame({"z": [3, 6, 9]}) + assert_frame_equal(result, expected) + result = df.select( + z=np.add( # type: ignore[call-overload] + 2, pl.col("a") + ) + ) + expected = pl.DataFrame({"z": [3, 4, 5]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/map/test_map_elements.py b/py-polars/tests/unit/operations/map/test_map_elements.py index c3c32caaf1e2..87ced66a1510 100644 --- a/py-polars/tests/unit/operations/map/test_map_elements.py +++ b/py-polars/tests/unit/operations/map/test_map_elements.py @@ -8,7 +8,7 @@ import polars as pl from polars.exceptions import PolarsInefficientMapWarning -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_equal def test_map_elements_infer_list() -> None: @@ -79,7 +79,7 @@ def test_datelike_identity() -> None: assert s.map_elements(lambda x: x).to_list() == s.to_list() -def test_map_elements_list_anyvalue_fallback() -> None: +def test_map_elements_list_any_value_fallback() -> None: with pytest.warns( PolarsInefficientMapWarning, match=r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()', @@ -295,8 +295,22 @@ def test_map_elements_on_empty_col_10639() -> None: } +def test_map_elements_chunked_14390() -> None: + s = pl.concat(2 * [pl.Series([1])], rechunk=False) + assert s.n_chunks() > 1 + assert_series_equal(s.map_elements(str), pl.Series(["1", "1"]), check_names=False) + + def test_apply_deprecated() -> None: with pytest.deprecated_call(): - pl.col("a").apply(lambda x: x + 1) + pl.col("a").apply(np.abs) with pytest.deprecated_call(): - pl.Series([1, 2, 3]).apply(lambda x: x + 1) + pl.Series([1, 2, 3]).apply(np.abs) + + +def test_cabbage_strategy_14396() -> None: + df = pl.DataFrame({"x": [1, 2, 3]}) + with pytest.raises( + ValueError, match="strategy 'cabbage' is not supported" + ), pytest.warns(PolarsInefficientMapWarning): + df.select(pl.col("x").map_elements(lambda x: 2 * x, strategy="cabbage")) # type: ignore[arg-type] diff --git a/py-polars/tests/unit/operations/test_aggregations.py b/py-polars/tests/unit/operations/test_aggregations.py index f45768ee154a..ad588032d8d0 100644 --- a/py-polars/tests/unit/operations/test_aggregations.py +++ b/py-polars/tests/unit/operations/test_aggregations.py @@ -387,7 +387,7 @@ def test_agg_filter_over_empty_df_13610() -> None: out = ( ldf.drop_nulls() - .group_by(by=["a"], maintain_order=True) + .group_by(["a"], maintain_order=True) .agg(pl.col("b").filter(pl.col("b").shift(1))) .collect() ) diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index 4fde74fc8026..4040c112cb24 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -1,12 +1,11 @@ from __future__ import annotations from datetime import date, datetime, time, timedelta -from typing import Any +from typing import TYPE_CHECKING, Any import pytest import polars as pl -from polars.exceptions import ComputeError from polars.testing import assert_frame_equal from polars.testing.asserts.series import assert_series_equal from polars.utils.convert import ( @@ -15,6 +14,9 @@ US_PER_SECOND, ) +if TYPE_CHECKING: + from polars import PolarsDataType + def test_string_date() -> None: df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns( @@ -28,7 +30,7 @@ def test_string_date() -> None: def test_invalid_string_date() -> None: df = pl.DataFrame({"x1": ["2021-01-aa"]}) - with pytest.raises(ComputeError): + with pytest.raises(pl.ComputeError): df.with_columns(**{"x1-date": pl.col("x1").cast(pl.Date)}) @@ -64,7 +66,7 @@ def test_string_datetime() -> None: def test_invalid_string_datetime() -> None: df = pl.DataFrame({"x1": ["2021-12-19 00:39:57", "2022-12-19 16:39:57"]}) - with pytest.raises(ComputeError): + with pytest.raises(pl.ComputeError): df.with_columns( **{"x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns"))} ) @@ -233,11 +235,11 @@ def test_strict_cast_int( assert _cast_expr(*args) == expected_value # type: ignore[arg-type] assert _cast_lit(*args) == expected_value # type: ignore[arg-type] else: - with pytest.raises(pl.exceptions.ComputeError): + with pytest.raises(pl.ComputeError): _cast_series(*args) # type: ignore[arg-type] - with pytest.raises(pl.exceptions.ComputeError): + with pytest.raises(pl.ComputeError): _cast_expr(*args) # type: ignore[arg-type] - with pytest.raises(pl.exceptions.ComputeError): + with pytest.raises(pl.ComputeError): _cast_lit(*args) # type: ignore[arg-type] @@ -372,11 +374,11 @@ def test_strict_cast_temporal( assert out.item() == expected_value assert out.dtype == to_dtype else: - with pytest.raises(pl.exceptions.ComputeError): + with pytest.raises(pl.ComputeError): _cast_series_t(*args) # type: ignore[arg-type] - with pytest.raises(pl.exceptions.ComputeError): + with pytest.raises(pl.ComputeError): _cast_expr_t(*args) # type: ignore[arg-type] - with pytest.raises(pl.exceptions.ComputeError): + with pytest.raises(pl.ComputeError): _cast_lit_t(*args) # type: ignore[arg-type] @@ -568,9 +570,53 @@ def test_strict_cast_string_and_binary( assert out.item() == expected_value assert out.dtype == to_dtype else: - with pytest.raises(pl.exceptions.ComputeError): + with pytest.raises(pl.ComputeError): _cast_series_t(*args) # type: ignore[arg-type] - with pytest.raises(pl.exceptions.ComputeError): + with pytest.raises(pl.ComputeError): _cast_expr_t(*args) # type: ignore[arg-type] - with pytest.raises(pl.exceptions.ComputeError): + with pytest.raises(pl.ComputeError): _cast_lit_t(*args) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "dtype_in", + [(pl.Categorical), (pl.Enum(["1"]))], +) +@pytest.mark.parametrize( + "dtype_out", + [ + (pl.UInt8), + (pl.Int8), + (pl.UInt16), + (pl.Int16), + (pl.UInt32), + (pl.Int32), + (pl.UInt64), + (pl.Int64), + (pl.Date), + (pl.Datetime), + (pl.Time), + (pl.Duration), + (pl.String), + (pl.Categorical), + (pl.Enum(["1", "2"])), + ], +) +def test_cast_categorical_name_retention( + dtype_in: PolarsDataType, dtype_out: PolarsDataType +) -> None: + assert pl.Series("a", ["1"], dtype=dtype_in).cast(dtype_out).name == "a" + + +def test_cast_date_to_time() -> None: + s = pl.Series([date(1970, 1, 1), date(2000, 12, 31)]) + msg = "cannot cast `Date` to `Time`" + with pytest.raises(pl.ComputeError, match=msg): + s.cast(pl.Time) + + +def test_cast_time_to_date() -> None: + s = pl.Series([time(0, 0), time(20, 00)]) + msg = "cannot cast `Time` to `Date`" + with pytest.raises(pl.ComputeError, match=msg): + s.cast(pl.Date) diff --git a/py-polars/tests/unit/operations/test_clip.py b/py-polars/tests/unit/operations/test_clip.py index 80e273e8a6c0..a341d3015a67 100644 --- a/py-polars/tests/unit/operations/test_clip.py +++ b/py-polars/tests/unit/operations/test_clip.py @@ -5,45 +5,58 @@ import pytest import polars as pl -from polars.testing.asserts.series import assert_series_equal +from polars.testing import assert_frame_equal, assert_series_equal -def test_clip() -> None: - clip_exprs = [ +@pytest.fixture() +def clip_exprs() -> list[pl.Expr]: + return [ pl.col("a").clip(pl.col("min"), pl.col("max")).alias("clip"), pl.col("a").clip(lower_bound=pl.col("min")).alias("clip_min"), pl.col("a").clip(upper_bound=pl.col("max")).alias("clip_max"), ] - df = pl.DataFrame( + +def test_clip_int(clip_exprs: list[pl.Expr]) -> None: + lf = pl.LazyFrame( { "a": [1, 2, 3, 4, 5], "min": [0, -1, 4, None, 4], "max": [2, 1, 8, 5, None], } ) + result = lf.select(clip_exprs) + expected = pl.LazyFrame( + { + "clip": [1, 1, 4, None, None], + "clip_min": [1, 2, 4, None, 5], + "clip_max": [1, 1, 3, 4, None], + } + ) + assert_frame_equal(result, expected) - assert df.select(clip_exprs).to_dict(as_series=False) == { - "clip": [1, 1, 4, None, None], - "clip_min": [1, 2, 4, None, 5], - "clip_max": [1, 1, 3, 4, None], - } - df = pl.DataFrame( +def test_clip_float(clip_exprs: list[pl.Expr]) -> None: + lf = pl.LazyFrame( { "a": [1.0, 2.0, 3.0, 4.0, 5.0], "min": [0, -1.0, 4.0, None, 4.0], "max": [2.0, 1.0, 8.0, 5.0, None], } ) + result = lf.select(clip_exprs) + expected = pl.LazyFrame( + { + "clip": [1.0, 1.0, 4.0, None, None], + "clip_min": [1.0, 2.0, 4.0, None, 5.0], + "clip_max": [1.0, 1.0, 3.0, 4.0, None], + } + ) + assert_frame_equal(result, expected) - assert df.select(clip_exprs).to_dict(as_series=False) == { - "clip": [1.0, 1.0, 4.0, None, None], - "clip_min": [1.0, 2.0, 4.0, None, 5.0], - "clip_max": [1.0, 1.0, 3.0, 4.0, None], - } - df = pl.DataFrame( +def test_clip_datetime(clip_exprs: list[pl.Expr]) -> None: + lf = pl.LazyFrame( { "a": [ datetime(1995, 6, 5, 10, 30), @@ -71,33 +84,57 @@ def test_clip() -> None: ], } ) + result = lf.select(clip_exprs) + expected = pl.LazyFrame( + { + "clip": [ + datetime(1995, 6, 5, 10, 30), + datetime(1996, 6, 5), + datetime(2023, 9, 20, 18, 30, 6), + None, + None, + None, + ], + "clip_min": [ + datetime(1995, 6, 5, 10, 30), + datetime(1996, 6, 5), + datetime(2023, 10, 20, 18, 30, 6), + None, + None, + datetime(2000, 1, 10), + ], + "clip_max": [ + datetime(1995, 6, 5, 10, 30), + datetime(1995, 6, 5), + datetime(2023, 9, 20, 18, 30, 6), + None, + datetime(1993, 3, 13), + None, + ], + } + ) + assert_frame_equal(result, expected) + + +def test_clip_non_numeric_dtype_fails() -> None: + msg = "`clip` only supports physical numeric types" + + s = pl.Series(["a", "b", "c"]) + with pytest.raises(pl.InvalidOperationError, match=msg): + s.clip(pl.lit("b"), pl.lit("z")) + + +def test_clip_string_input() -> None: + df = pl.DataFrame({"a": [0, 1, 2], "min": [1, None, 1]}) + result = df.select(pl.col("a").clip("min")) + expected = pl.DataFrame({"a": [1, None, 2]}) + assert_frame_equal(result, expected) + - assert df.select(clip_exprs).to_dict(as_series=False) == { - "clip": [ - datetime(1995, 6, 5, 10, 30), - datetime(1996, 6, 5), - datetime(2023, 9, 20, 18, 30, 6), - None, - None, - None, - ], - "clip_min": [ - datetime(1995, 6, 5, 10, 30), - datetime(1996, 6, 5), - datetime(2023, 10, 20, 18, 30, 6), - None, - None, - datetime(2000, 1, 10), - ], - "clip_max": [ - datetime(1995, 6, 5, 10, 30), - datetime(1995, 6, 5), - datetime(2023, 9, 20, 18, 30, 6), - None, - datetime(1993, 3, 13), - None, - ], - } +def test_clip_bound_invalid_for_original_dtype() -> None: + s = pl.Series([1, 2, 3, 4], dtype=pl.UInt32) + with pytest.raises(pl.ComputeError, match="conversion from `i32` to `u32` failed"): + s.clip(-1, 5) def test_clip_min_max_deprecated() -> None: diff --git a/py-polars/tests/unit/operations/test_ewm.py b/py-polars/tests/unit/operations/test_ewm.py new file mode 100644 index 000000000000..faf0750c689b --- /dev/null +++ b/py-polars/tests/unit/operations/test_ewm.py @@ -0,0 +1,299 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest +from hypothesis import given +from hypothesis.strategies import booleans, floats + +import polars as pl +from polars.expr.expr import _prepare_alpha +from polars.testing import assert_series_equal +from polars.testing.parametric import series + + +def test_ewm_mean() -> None: + s = pl.Series([2, 5, 3]) + + expected = pl.Series([2.0, 4.0, 3.4285714285714284]) + assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected + ) + + expected = pl.Series([2.0, 3.8, 3.421053]) + assert_series_equal(s.ewm_mean(com=2.0, adjust=True, ignore_nulls=True), expected) + assert_series_equal(s.ewm_mean(com=2.0, adjust=True, ignore_nulls=False), expected) + + expected = pl.Series([2.0, 3.5, 3.25]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected + ) + + s = pl.Series([2, 3, 5, 7, 4]) + + expected = pl.Series([None, 2.666667, 4.0, 5.6, 4.774194]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_periods=2, ignore_nulls=True), expected + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_periods=2, ignore_nulls=False), expected + ) + + expected = pl.Series([None, None, 4.0, 5.6, 4.774194]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_periods=3, ignore_nulls=True), expected + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_periods=3, ignore_nulls=False), expected + ) + + s = pl.Series([None, 1.0, 5.0, 7.0, None, 2.0, 5.0, 4]) + + expected = pl.Series( + [ + None, + 1.0, + 3.6666666666666665, + 5.571428571428571, + 5.571428571428571, + 3.6666666666666665, + 4.354838709677419, + 4.174603174603175, + ], + ) + assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected) + expected = pl.Series( + [ + None, + 1.0, + 3.666666666666667, + 5.571428571428571, + 5.571428571428571, + 3.08695652173913, + 4.2, + 4.092436974789916, + ] + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected + ) + + expected = pl.Series([None, 1.0, 3.0, 5.0, 5.0, 3.5, 4.25, 4.125]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected + ) + + expected = pl.Series([None, 1.0, 3.0, 5.0, 5.0, 3.0, 4.0, 4.0]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected + ) + + +def test_ewm_mean_leading_nulls() -> None: + for min_periods in [1, 2, 3]: + assert ( + pl.Series([1, 2, 3, 4]) + .ewm_mean(com=3, min_periods=min_periods, ignore_nulls=False) + .null_count() + == min_periods - 1 + ) + assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( + alpha=0.5, min_periods=1, ignore_nulls=True + ).to_list() == [None, 1.0, 1.0, 1.0] + assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( + alpha=0.5, min_periods=2, ignore_nulls=True + ).to_list() == [None, None, 1.0, 1.0] + + +def test_ewm_mean_min_periods() -> None: + series = pl.Series([1.0, None, None, None]) + + ewm_mean = series.ewm_mean(alpha=0.5, min_periods=1, ignore_nulls=True) + assert ewm_mean.to_list() == [1.0, 1.0, 1.0, 1.0] + ewm_mean = series.ewm_mean(alpha=0.5, min_periods=2, ignore_nulls=True) + assert ewm_mean.to_list() == [None, None, None, None] + + series = pl.Series([1.0, None, 2.0, None, 3.0]) + + ewm_mean = series.ewm_mean(alpha=0.5, min_periods=1, ignore_nulls=True) + assert_series_equal( + ewm_mean, + pl.Series( + [ + 1.0, + 1.0, + 1.6666666666666665, + 1.6666666666666665, + 2.4285714285714284, + ] + ), + ) + ewm_mean = series.ewm_mean(alpha=0.5, min_periods=2, ignore_nulls=True) + assert_series_equal( + ewm_mean, + pl.Series( + [ + None, + None, + 1.6666666666666665, + 1.6666666666666665, + 2.4285714285714284, + ] + ), + ) + + +def test_ewm_std_var() -> None: + series = pl.Series("a", [2, 5, 3]) + + var = series.ewm_var(alpha=0.5, ignore_nulls=False) + std = series.ewm_std(alpha=0.5, ignore_nulls=False) + + assert np.allclose(var, std**2, rtol=1e-16) + + +def test_ewm_param_validation() -> None: + s = pl.Series("values", range(10)) + + with pytest.raises(ValueError, match="mutually exclusive"): + s.ewm_std(com=0.5, alpha=0.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="mutually exclusive"): + s.ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False) + + with pytest.raises(ValueError, match="mutually exclusive"): + s.ewm_var(alpha=0.5, span=1.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="require `com` >= 0"): + s.ewm_std(com=-0.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="require `span` >= 1"): + s.ewm_mean(span=0.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="require `half_life` > 0"): + s.ewm_var(half_life=0, ignore_nulls=False) + + for alpha in (-0.5, -0.0000001, 0.0, 1.0000001, 1.5): + with pytest.raises(ValueError, match="require 0 < `alpha` <= 1"): + s.ewm_std(alpha=alpha, ignore_nulls=False) + + +# https://github.com/pola-rs/polars/issues/4951 +def test_ewm_with_multiple_chunks() -> None: + df0 = pl.DataFrame( + data=[ + ("w", 6.0, 1.0), + ("x", 5.0, 2.0), + ("y", 4.0, 3.0), + ("z", 3.0, 4.0), + ], + schema=["a", "b", "c"], + ).with_columns( + [ + pl.col(pl.Float64).log().diff().name.prefix("ld_"), + ] + ) + assert df0.n_chunks() == 1 + + # NOTE: We aren't testing whether `select` creates two chunks; + # we just need two chunks to properly test `ewm_mean` + df1 = df0.select(["ld_b", "ld_c"]) + assert df1.n_chunks() == 2 + + ewm_std = df1.with_columns( + pl.all().ewm_std(com=20, ignore_nulls=False).name.prefix("ewm_"), + ) + assert ewm_std.null_count().sum_horizontal()[0] == 4 + + +def alpha_guard(**decay_param: float) -> bool: + """Protects against unnecessary noise in small number regime.""" + if not next(iter(decay_param.values())): + return True + alpha = _prepare_alpha(**decay_param) + return ((1 - alpha) if round(alpha) else alpha) > 1e-6 + + +@given( + s=series( + min_size=4, + dtype=pl.Float64, + null_probability=0.05, + strategy=floats(min_value=-1e8, max_value=1e8), + ), + half_life=floats(min_value=0, max_value=4, exclude_min=True).filter( + lambda x: alpha_guard(half_life=x) + ), + com=floats(min_value=0, max_value=99).filter(lambda x: alpha_guard(com=x)), + span=floats(min_value=1, max_value=10).filter(lambda x: alpha_guard(span=x)), + ignore_nulls=booleans(), + adjust=booleans(), + bias=booleans(), +) +def test_ewm_methods( + s: pl.Series, + com: float | None, + span: float | None, + half_life: float | None, + ignore_nulls: bool, + adjust: bool, + bias: bool, +) -> None: + # validate a large set of varied EWM calculations + for decay_param in [{"com": com}, {"span": span}, {"half_life": half_life}]: + alpha = _prepare_alpha(**decay_param) + + # convert parametrically-generated series to pandas, then use that as a + # reference implementation for comparison (after normalising NaN/None) + p = s.to_pandas() + + # note: skip min_periods < 2, due to pandas-side inconsistency: + # https://github.com/pola-rs/polars/issues/5006#issuecomment-1259477178 + for mp in range(2, len(s), len(s) // 3): + # consolidate ewm parameters + pl_params: dict[str, Any] = { + "min_periods": mp, + "adjust": adjust, + "ignore_nulls": ignore_nulls, + } + pl_params.update(decay_param) + pd_params = pl_params.copy() + if "half_life" in pl_params: + pd_params["halflife"] = pd_params.pop("half_life") + if "ignore_nulls" in pl_params: + pd_params["ignore_na"] = pd_params.pop("ignore_nulls") + + # mean: + ewm_mean_pl = s.ewm_mean(**pl_params).fill_nan(None) + ewm_mean_pd = pl.Series(p.ewm(**pd_params).mean()) + if alpha == 1: + # apply fill-forward to nulls to match pandas + # https://github.com/pola-rs/polars/pull/5011#issuecomment-1262318124 + ewm_mean_pl = ewm_mean_pl.fill_null(strategy="forward") + + assert_series_equal(ewm_mean_pl, ewm_mean_pd, atol=1e-07) + + # std: + ewm_std_pl = s.ewm_std(bias=bias, **pl_params).fill_nan(None) + ewm_std_pd = pl.Series(p.ewm(**pd_params).std(bias=bias)) + assert_series_equal(ewm_std_pl, ewm_std_pd, atol=1e-07) + + # var: + ewm_var_pl = s.ewm_var(bias=bias, **pl_params).fill_nan(None) + ewm_var_pd = pl.Series(p.ewm(**pd_params).var(bias=bias)) + assert_series_equal(ewm_var_pl, ewm_var_pd, atol=1e-07) + + +def test_ewm_ignore_nulls_deprecation() -> None: + s = pl.Series([1, None, 3]) + with pytest.deprecated_call(): + s.ewm_mean(com=1.0) + with pytest.deprecated_call(): + s.ewm_std(com=1.0) + with pytest.deprecated_call(): + s.ewm_var(com=1.0) diff --git a/py-polars/tests/unit/operations/test_extend_constant.py b/py-polars/tests/unit/operations/test_extend_constant.py new file mode 100644 index 000000000000..aa6a3bdc2f6b --- /dev/null +++ b/py-polars/tests/unit/operations/test_extend_constant.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +@pytest.mark.parametrize( + ("const", "dtype"), + [ + (1, pl.Int8), + (4, pl.UInt32), + (4.5, pl.Float32), + (None, pl.Float64), + ("白鵬翔", pl.String), + (date.today(), pl.Date), + (datetime.now(), pl.Datetime("ns")), + (time(23, 59, 59), pl.Time), + (timedelta(hours=7, seconds=123), pl.Duration("ms")), + ], +) +def test_extend_constant(const: Any, dtype: pl.PolarsDataType) -> None: + df = pl.DataFrame({"a": pl.Series("s", [None], dtype=dtype)}) + + expected_df = pl.DataFrame( + {"a": pl.Series("s", [None, const, const, const], dtype=dtype)} + ) + + assert_frame_equal(df.select(pl.col("a").extend_constant(const, 3)), expected_df) + + s = pl.Series("s", [None], dtype=dtype) + expected = pl.Series("s", [None, const, const, const], dtype=dtype) + assert_series_equal(s.extend_constant(const, 3), expected) + + # test n expr + expected = pl.Series("s", [None, const, const], dtype=dtype) + assert_series_equal(s.extend_constant(const, pl.Series([2])), expected) + + # test value expr + expected = pl.Series("s", [None, const, const, const], dtype=dtype) + assert_series_equal(s.extend_constant(pl.Series([const], dtype=dtype), 3), expected) + + +@pytest.mark.parametrize( + ("const", "dtype"), + [ + (1, pl.Int8), + (4, pl.UInt32), + (4.5, pl.Float32), + (None, pl.Float64), + ("白鵬翔", pl.String), + (date.today(), pl.Date), + (datetime.now(), pl.Datetime("ns")), + (time(23, 59, 59), pl.Time), + (timedelta(hours=7, seconds=123), pl.Duration("ms")), + ], +) +def test_extend_constant_arr(const: Any, dtype: pl.PolarsDataType) -> None: + """ + Test extend_constant in pl.List array. + + NOTE: This function currently fails when the Series is a list with a single [None] + value. Hence, this function does not begin with [[None]], but [[const]]. + """ + s = pl.Series("s", [[const]], dtype=pl.List(dtype)) + + expected = pl.Series("s", [[const, const, const, const]], dtype=pl.List(dtype)) + + assert_series_equal(s.list.eval(pl.element().extend_constant(const, 3)), expected) + + +def test_extend_by_not_uint_expr() -> None: + s = pl.Series("s", [1]) + with pytest.raises(pl.ComputeError, match="value and n should have unit length"): + s.extend_constant(pl.Series([2, 3]), 3) + with pytest.raises(pl.ComputeError, match="value and n should have unit length"): + s.extend_constant(2, pl.Series([3, 4])) diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index ad68990ff7e2..97ebba213cd2 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -648,19 +648,19 @@ def test_overflow_mean_partitioned_group_by_5194(dtype: pl.PolarsDataType) -> No assert result.to_dict(as_series=False) == expected +# https://github.com/pola-rs/polars/issues/7181 def test_group_by_multiple_column_reference() -> None: - # Issue #7181 df = pl.DataFrame( { "gr": ["a", "b", "a", "b", "a", "b"], "val": [1, 20, 100, 2000, 10000, 200000], } ) - res = df.group_by("gr").agg( + result = df.group_by("gr").agg( pl.col("val") + pl.col("val").shift().fill_null(0), ) - assert res.sort("gr").to_dict(as_series=False) == { + assert result.sort("gr").to_dict(as_series=False) == { "gr": ["a", "b"], "val": [[1, 101, 10100], [20, 2020, 202000]], } @@ -917,3 +917,35 @@ def test_group_by_all_12869() -> None: df = pl.DataFrame({"a": [1]}) result = next(iter(df.group_by(pl.all())))[1] assert_frame_equal(df, result) + + +def test_group_by_named() -> None: + df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": range(6)}) + result = df.group_by(z=pl.col("a") * 2, maintain_order=True).agg(pl.col("b").min()) + expected = df.group_by((pl.col("a") * 2).alias("z"), maintain_order=True).agg( + pl.col("b").min() + ) + assert_frame_equal(result, expected) + + +def test_group_by_deprecated_by_arg() -> None: + df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": range(6)}) + with pytest.deprecated_call(): + result = df.group_by(by=(pl.col("a") * 2), maintain_order=True).agg( + pl.col("b").min() + ) + expected = df.group_by((pl.col("a") * 2), maintain_order=True).agg( + pl.col("b").min() + ) + assert_frame_equal(result, expected) + + +def test_group_by_with_null() -> None: + df = pl.DataFrame( + {"a": [None, None, None, None], "b": [1, 1, 2, 2], "c": ["x", "y", "z", "u"]} + ) + expected = pl.DataFrame( + {"a": [None, None], "b": [1, 2], "c": [["x", "y"], ["z", "u"]]} + ) + output = df.group_by(["a", "b"], maintain_order=True).agg(pl.col("c")) + assert_frame_equal(expected, output) diff --git a/py-polars/tests/unit/operations/test_is_in.py b/py-polars/tests/unit/operations/test_is_in.py index 8805e47f7104..3780410812de 100644 --- a/py-polars/tests/unit/operations/test_is_in.py +++ b/py-polars/tests/unit/operations/test_is_in.py @@ -312,3 +312,14 @@ def test_is_in_with_wildcard_13809() -> None: out = pl.DataFrame({"A": ["B"]}).select(pl.all().is_in(["C"])) expected = pl.DataFrame({"A": [False]}) assert_frame_equal(out, expected) + + +@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c", "d"])]) +def test_cat_is_in_from_str(dtype: pl.DataType) -> None: + s = pl.Series(["c", "c", "b"], dtype=dtype) + + # test local + assert_series_equal( + pl.Series(["a", "d", "e", "b"]).is_in(s), + pl.Series([False, False, False, True]), + ) diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index 2c908dc1806b..97b29dd6aeed 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -662,69 +662,86 @@ def test_outer_join_list_() -> None: @pytest.mark.slow() def test_join_validation() -> None: def test_each_join_validation( - unique: pl.DataFrame, duplicate: pl.DataFrame, how: JoinStrategy + unique: pl.DataFrame, duplicate: pl.DataFrame, on: str, how: JoinStrategy ) -> None: # one_to_many _one_to_many_success_inner = unique.join( - duplicate, on="id", how=how, validate="1:m" + duplicate, on=on, how=how, validate="1:m" ) with pytest.raises(pl.ComputeError): _one_to_many_fail_inner = duplicate.join( - unique, on="id", how=how, validate="1:m" + unique, on=on, how=how, validate="1:m" ) # one to one with pytest.raises(pl.ComputeError): _one_to_one_fail_1_inner = unique.join( - duplicate, on="id", how=how, validate="1:1" + duplicate, on=on, how=how, validate="1:1" ) with pytest.raises(pl.ComputeError): _one_to_one_fail_2_inner = duplicate.join( - unique, on="id", how=how, validate="1:1" + unique, on=on, how=how, validate="1:1" ) # many to one with pytest.raises(pl.ComputeError): _many_to_one_fail_inner = unique.join( - duplicate, on="id", how=how, validate="m:1" + duplicate, on=on, how=how, validate="m:1" ) _many_to_one_success_inner = duplicate.join( - unique, on="id", how=how, validate="m:1" + unique, on=on, how=how, validate="m:1" ) # many to many _many_to_many_success_1_inner = duplicate.join( - unique, on="id", how=how, validate="m:m" + unique, on=on, how=how, validate="m:m" ) _many_to_many_success_2_inner = unique.join( - duplicate, on="id", how=how, validate="m:m" + duplicate, on=on, how=how, validate="m:m" ) # test data short_unique = pl.DataFrame( - {"id": [1, 2, 3, 4], "name": ["hello", "world", "rust", "polars"]} + { + "id": [1, 2, 3, 4], + "id_str": ["1", "2", "3", "4"], + "name": ["hello", "world", "rust", "polars"], + } + ) + short_duplicate = pl.DataFrame( + {"id": [1, 2, 3, 1], "id_str": ["1", "2", "3", "1"], "cnt": [2, 4, 6, 1]} ) - short_duplicate = pl.DataFrame({"id": [1, 2, 3, 1], "cnt": [2, 4, 6, 1]}) long_unique = pl.DataFrame( - {"id": [1, 2, 3, 4, 5], "name": ["hello", "world", "rust", "polars", "meow"]} + { + "id": [1, 2, 3, 4, 5], + "id_str": ["1", "2", "3", "4", "5"], + "name": ["hello", "world", "rust", "polars", "meow"], + } + ) + long_duplicate = pl.DataFrame( + { + "id": [1, 2, 3, 1, 5], + "id_str": ["1", "2", "3", "1", "5"], + "cnt": [2, 4, 6, 1, 8], + } ) - long_duplicate = pl.DataFrame({"id": [1, 2, 3, 1, 5], "cnt": [2, 4, 6, 1, 8]}) join_strategies: list[JoinStrategy] = ["inner", "outer", "left"] - for how in join_strategies: - # same size - test_each_join_validation(long_unique, long_duplicate, how) + for join_col in ["id", "id_str"]: + for how in join_strategies: + # same size + test_each_join_validation(long_unique, long_duplicate, join_col, how) - # left longer - test_each_join_validation(long_unique, short_duplicate, how) + # left longer + test_each_join_validation(long_unique, short_duplicate, join_col, how) - # right longer - test_each_join_validation(short_unique, long_duplicate, how) + # right longer + test_each_join_validation(short_unique, long_duplicate, join_col, how) def test_outer_join_bool() -> None: diff --git a/py-polars/tests/unit/operations/test_melt.py b/py-polars/tests/unit/operations/test_melt.py index 12c12c45a581..2d75ab480c1f 100644 --- a/py-polars/tests/unit/operations/test_melt.py +++ b/py-polars/tests/unit/operations/test_melt.py @@ -69,3 +69,15 @@ def test_melt_projection_pd_7747() -> None: } ) assert_frame_equal(result, expected) + + +# https://github.com/pola-rs/polars/issues/10075 +def test_melt_no_value_vars() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}) + + result = lf.melt("a") + + expected = pl.LazyFrame( + schema={"a": pl.Int64, "variable": pl.String, "value": pl.Null} + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_pivot.py b/py-polars/tests/unit/operations/test_pivot.py index 097a9f93a453..e5354703052b 100644 --- a/py-polars/tests/unit/operations/test_pivot.py +++ b/py-polars/tests/unit/operations/test_pivot.py @@ -18,11 +18,11 @@ def test_pivot() -> None: df = pl.DataFrame( { "foo": ["A", "A", "B", "B", "C"], - "N": [1, 2, 2, 4, 2], "bar": ["k", "l", "m", "n", "o"], + "N": [1, 2, 2, 4, 2], } ) - result = df.pivot(values="N", index="foo", columns="bar", aggregate_function=None) + result = df.pivot(index="foo", columns="bar", values="N", aggregate_function=None) expected = pl.DataFrame( [ @@ -35,6 +35,35 @@ def test_pivot() -> None: assert_frame_equal(result, expected) +def test_pivot_no_values() -> None: + df = pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "bar": ["k", "l", "m", "n", "o"], + "N1": [1, 2, 2, 4, 2], + "N2": [1, 2, 2, 4, 2], + } + ) + result = df.pivot(index="foo", columns="bar", values=None, aggregate_function=None) + expected = pl.DataFrame( + { + "foo": ["A", "B", "C"], + "N1_bar_k": [1, None, None], + "N1_bar_l": [2, None, None], + "N1_bar_m": [None, 2, None], + "N1_bar_n": [None, 4, None], + "N1_bar_o": [None, None, 2], + "N2_bar_k": [1, None, None], + "N2_bar_l": [2, None, None], + "N2_bar_m": [None, 2, None], + "N2_bar_n": [None, 4, None], + "N2_bar_o": [None, None, 2], + } + ) + + assert_frame_equal(result, expected) + + def test_pivot_list() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 1], [2, 2], [3, 3]]}) @@ -47,7 +76,11 @@ def test_pivot_list() -> None: } ) out = df.pivot( - "b", index="a", columns="a", aggregate_function="first", sort_columns=True + index="a", + columns="a", + values="b", + aggregate_function="first", + sort_columns=True, ) assert_frame_equal(out, expected) @@ -73,7 +106,7 @@ def test_pivot_aggregate(agg_fn: PivotAgg, expected_rows: list[tuple[Any]]) -> N } ) result = df.pivot( - values="c", index="b", columns="a", aggregate_function=agg_fn, sort_columns=True + index="b", columns="a", values="c", aggregate_function=agg_fn, sort_columns=True ) assert result.rows() == expected_rows @@ -106,12 +139,12 @@ def test_pivot_categorical_index() -> None: schema=[("A", pl.Categorical), ("B", pl.Categorical)], ) - result = df.pivot(values="B", index=["A"], columns="B", aggregate_function="len") + result = df.pivot(index=["A"], columns="B", values="B", aggregate_function="len") expected = {"A": ["Fire", "Water"], "Car": [1, 2], "Ship": [1, None]} assert result.to_dict(as_series=False) == expected # test expression dispatch - result = df.pivot(values="B", index=["A"], columns="B", aggregate_function=pl.len()) + result = df.pivot(index=["A"], columns="B", values="B", aggregate_function=pl.len()) assert result.to_dict(as_series=False) == expected df = pl.DataFrame( @@ -123,7 +156,7 @@ def test_pivot_categorical_index() -> None: schema=[("A", pl.Categorical), ("B", pl.Categorical), ("C", pl.Categorical)], ) result = df.pivot( - values="B", index=["A", "C"], columns="B", aggregate_function="len" + index=["A", "C"], columns="B", values="B", aggregate_function="len" ) expected = { "A": ["Fire", "Water"], @@ -146,17 +179,17 @@ def test_pivot_multiple_values_column_names_5116() -> None: with pytest.raises(ComputeError, match="found multiple elements in the same group"): df.pivot( - values=["x1", "x2"], index="c1", columns="c2", + values=["x1", "x2"], separator="|", aggregate_function=None, ) result = df.pivot( - values=["x1", "x2"], index="c1", columns="c2", + values=["x1", "x2"], separator="|", aggregate_function="first", ) @@ -180,20 +213,97 @@ def test_pivot_duplicate_names_7731() -> None: "e": ["x", "y"], } ) - assert df.pivot( - values=cs.integer(), + result = df.pivot( index=cs.float(), columns=cs.string(), + values=cs.integer(), aggregate_function="first", - ).to_dict(as_series=False) == { + ).to_dict(as_series=False) + expected = { "b": [1.5, 2.5], - "a_c_x": [1, 4], - "d_c_x": [7, 8], - "a_e_x": [1, None], - "a_e_y": [None, 4], - "d_e_x": [7, None], - "d_e_y": [None, 8], + 'a_{"c","e"}_{"x","x"}': [1, None], + 'a_{"c","e"}_{"x","y"}': [None, 4], + 'd_{"c","e"}_{"x","x"}': [7, None], + 'd_{"c","e"}_{"x","y"}': [None, 8], } + assert result == expected + + +def test_pivot_duplicate_names_11663() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [1, 2], "c": ["x", "x"], "d": ["x", "y"]}) + result = df.pivot(index="b", columns=["c", "d"], values="a").to_dict( + as_series=False + ) + expected = {"b": [1, 2], '{"x","x"}': [1, None], '{"x","y"}': [None, 2]} + assert result == expected + + +def test_pivot_multiple_columns_12407() -> None: + df = pl.DataFrame( + { + "a": ["beep", "bop"], + "b": ["a", "b"], + "c": ["s", "f"], + "d": [7, 8], + "e": ["x", "y"], + } + ) + result = df.pivot( + index="b", columns=["c", "e"], values=["a"], aggregate_function="len" + ).to_dict(as_series=False) + expected = {"b": ["a", "b"], '{"s","x"}': [1, None], '{"f","y"}': [None, 1]} + assert result == expected + + +def test_pivot_struct_13120() -> None: + df = pl.DataFrame( + { + "index": [1, 2, 3, 1, 2, 3], + "item_type": ["a", "a", "a", "b", "b", "b"], + "item_id": [123, 123, 123, 456, 456, 456], + "values": [4, 5, 6, 7, 8, 9], + } + ) + df = df.with_columns(pl.struct(["item_type", "item_id"]).alias("columns")).drop( + "item_type", "item_id" + ) + result = df.pivot(index="index", columns="columns", values="values").to_dict( + as_series=False + ) + expected = {"index": [1, 2, 3], '{"a",123}': [4, 5, 6], '{"b",456}': [7, 8, 9]} + assert result == expected + + +def test_pivot_index_struct_14101() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 1], + "b": [{"a": 1}, {"a": 1}, {"a": 2}], + "c": ["x", "y", "y"], + "d": [1, 1, 3], + } + ) + result = df.pivot(index="b", columns="c", values="a") + expected = pl.DataFrame({"b": [{"a": 1}, {"a": 2}], "x": [1, None], "y": [2, 1]}) + assert_frame_equal(result, expected) + + +def test_pivot_name_already_exists() -> None: + # This should be extremely rare...but still, good to check it + df = pl.DataFrame( + { + "a": ["a", "b"], + "b": ["a", "b"], + '{"a","b"}': [1, 2], + } + ) + with pytest.raises(ComputeError, match="already exists in the DataFrame"): + df.pivot( + values='{"a","b"}', + index="a", + columns=["a", "b"], + aggregate_function="first", + ) def test_pivot_floats() -> None: @@ -208,11 +318,11 @@ def test_pivot_floats() -> None: with pytest.raises(ComputeError, match="found multiple elements in the same group"): result = df.pivot( - values="price", index="weight", columns="quantity", aggregate_function=None + index="weight", columns="quantity", values="price", aggregate_function=None ) result = df.pivot( - values="price", index="weight", columns="quantity", aggregate_function="first" + index="weight", columns="quantity", values="price", aggregate_function="first" ) expected = { "weight": [1.0, 4.4, 8.8], @@ -223,9 +333,9 @@ def test_pivot_floats() -> None: assert result.to_dict(as_series=False) == expected result = df.pivot( - values="price", index=["article", "weight"], columns="quantity", + values="price", aggregate_function=None, ) expected = { @@ -248,21 +358,12 @@ def test_pivot_reinterpret_5907() -> None: ) result = df.pivot( - index=["A"], values=["C"], columns=["B"], aggregate_function=pl.element().sum() + index=["A"], columns=["B"], values=["C"], aggregate_function=pl.element().sum() ) expected = {"A": [3, -2], "x": [100, 50], "y": [500, -80]} assert result.to_dict(as_series=False) == expected -def test_pivot_subclassed_df() -> None: - class SubClassedDataFrame(pl.DataFrame): - pass - - df = SubClassedDataFrame({"a": [1, 2], "b": [3, 4]}) - result = df.pivot(values="b", index="a", columns="a", aggregate_function="first") - assert isinstance(result, SubClassedDataFrame) - - def test_pivot_temporal_logical_types() -> None: date_lst = [datetime(_, 1, 1) for _ in range(1960, 1980)] @@ -312,12 +413,36 @@ def test_pivot_negative_duration() -> None: } -def test_aggregate_function_deprecation_warning() -> None: +def test_aggregate_function_default() -> None: df = pl.DataFrame({"a": [1, 2], "b": ["foo", "foo"], "c": ["x", "x"]}) with pytest.raises( pl.ComputeError, match="found multiple elements in the same group" ): - df.pivot("a", "b", "c") + df.pivot(index="b", columns="c", values="a") + + +def test_pivot_positional_args_deprecated() -> None: + df = pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "N": [1, 2, 2, 4, 2], + "bar": ["k", "l", "m", "n", "o"], + } + ) + with pytest.deprecated_call(): + df.pivot("N", "foo", "bar", aggregate_function=None) + + +def test_pivot_aggregate_function_count_deprecated() -> None: + df = pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "N": [1, 2, 2, 4, 2], + "bar": ["k", "l", "m", "n", "o"], + } + ) + with pytest.deprecated_call(): + df.pivot(index="foo", columns="bar", values="N", aggregate_function="count") # type: ignore[arg-type] def test_pivot_struct() -> None: @@ -354,3 +479,25 @@ def test_pivot_struct() -> None: {"num1": 4, "num2": 4}, ], } + + +def test_duplicate_column_names_which_should_raise_14305() -> None: + df = pl.DataFrame({"a": [1, 3, 2], "c": ["a", "a", "a"], "d": [7, 8, 9]}) + with pytest.raises(pl.DuplicateError, match="has more than one occurrences"): + df.pivot(index="a", columns="c", values="d") + + +def test_multi_index_containing_struct() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 1], + "b": [{"a": 1}, {"a": 1}, {"a": 2}], + "c": ["x", "y", "y"], + "d": [1, 1, 3], + } + ) + result = df.pivot(index=("b", "d"), columns="c", values="a") + expected = pl.DataFrame( + {"b": [{"a": 1}, {"a": 2}], "d": [1, 3], "x": [1, None], "y": [2, 1]} + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_random.py b/py-polars/tests/unit/operations/test_random.py index 71195f46d239..88c11e3bc1d3 100644 --- a/py-polars/tests/unit/operations/test_random.py +++ b/py-polars/tests/unit/operations/test_random.py @@ -50,6 +50,7 @@ def test_sample_expr() -> None: def test_sample_df() -> None: df = pl.DataFrame({"foo": [1, 2, 3], "bar": [6, 7, 8], "ham": ["a", "b", "c"]}) + assert df.sample().shape == (1, 3) assert df.sample(n=2, seed=0).shape == (2, 3) assert df.sample(fraction=0.4, seed=0).shape == (1, 3) assert df.sample(n=pl.Series([2]), seed=0).shape == (2, 3) @@ -59,6 +60,8 @@ def test_sample_df() -> None: 1, 1, ) + with pytest.raises(ValueError, match="cannot specify both `n` and `fraction`"): + df.sample(n=2, fraction=0.4) def test_sample_n_expr() -> None: diff --git a/py-polars/tests/unit/operations/test_rolling.py b/py-polars/tests/unit/operations/test_rolling.py index 592bb17673a1..4c7c2b788560 100644 --- a/py-polars/tests/unit/operations/test_rolling.py +++ b/py-polars/tests/unit/operations/test_rolling.py @@ -9,17 +9,18 @@ from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: - from polars.type_aliases import ClosedInterval + from polars.type_aliases import ClosedInterval, PolarsIntegerType -def test_rolling_group_by_overlapping_groups() -> None: +@pytest.mark.parametrize("dtype", [pl.UInt32, pl.UInt64, pl.Int32, pl.Int64]) +def test_rolling_group_by_overlapping_groups(dtype: PolarsIntegerType) -> None: # this first aggregates overlapping groups so they cannot be naively flattened df = pl.DataFrame({"a": [41, 60, 37, 51, 52, 39, 40]}) assert_series_equal( ( df.with_row_index() - .with_columns(pl.col("index").cast(pl.Int32)) + .with_columns(pl.col("index").cast(dtype)) .rolling(index_column="index", period="5i") .agg( # trigger the apply on the expression engine @@ -31,12 +32,17 @@ def test_rolling_group_by_overlapping_groups() -> None: @pytest.mark.parametrize("input", [[pl.col("b").sum()], pl.col("b").sum()]) -def test_rolling_agg_input_types(input: Any) -> None: - df = pl.LazyFrame({"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]}).set_sorted( - "index_column" - ) +@pytest.mark.parametrize("dtype", [pl.UInt32, pl.UInt64, pl.Int32, pl.Int64]) +def test_rolling_agg_input_types(input: Any, dtype: PolarsIntegerType) -> None: + df = pl.LazyFrame( + {"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]}, + schema_overrides={"index_column": dtype}, + ).set_sorted("index_column") result = df.rolling(index_column="index_column", period="2i").agg(input) - expected = pl.LazyFrame({"index_column": [0, 1, 2, 3], "b": [1, 4, 4, 3]}) + expected = pl.LazyFrame( + {"index_column": [0, 1, 2, 3], "b": [1, 4, 4, 3]}, + schema_overrides={"index_column": dtype}, + ) assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_sets.py b/py-polars/tests/unit/operations/test_sets.py index 15bfcc0f3d95..88b153dbd99a 100644 --- a/py-polars/tests/unit/operations/test_sets.py +++ b/py-polars/tests/unit/operations/test_sets.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import pytest + import polars as pl @@ -13,3 +17,35 @@ def test_set_intersection_13765() -> None: df = df.filter(pl.col("f") == 1) df.select(pl.col("a").list.set_intersection("a_other")).to_dict(as_series=False) + + +@pytest.mark.parametrize( + ("set_operation", "outcome"), + [ + ("set_difference", [{"z1", "z"}, {"z"}, set(), {"z", "x2"}, {"z", "x3"}]), + ("set_intersection", [{"x", "y"}, {"y"}, {"y", "x"}, {"x", "y"}, set()]), + ( + "set_symmetric_difference", + [{"z1", "z"}, {"x", "z"}, set(), {"z", "x2"}, {"x", "y", "z", "x3"}], + ), + ], +) +def test_set_operations_cats(set_operation: str, outcome: list[set[str]]) -> None: + with pytest.warns(pl.CategoricalRemappingWarning): + df = pl.DataFrame( + { + "a": [ + ["z1", "x", "y", "z"], + ["y", "z"], + ["x", "y"], + ["x", "y", "z", "x2"], + ["z", "x3"], + ] + }, + schema={"a": pl.List(pl.Categorical)}, + ) + df = df.with_columns( + getattr(pl.col("a").list, set_operation)(["x", "y"]).alias("b") + ) + assert df.get_column("b").dtype == pl.List(pl.Categorical) + assert [set(el) for el in df["b"].to_list()] == outcome diff --git a/py-polars/tests/unit/operations/test_slice.py b/py-polars/tests/unit/operations/test_slice.py index 0aa4fea66209..e163d36806ae 100644 --- a/py-polars/tests/unit/operations/test_slice.py +++ b/py-polars/tests/unit/operations/test_slice.py @@ -168,3 +168,48 @@ def test_slice_pushdown_set_sorted() -> None: plan = ldf.explain() # check the set sorted is above slice assert plan.index("set_sorted") < plan.index("SLICE") + + +def test_slice_pushdown_literal_projection_14349() -> None: + lf = pl.select(a=pl.int_range(10)).lazy() + expect = pl.DataFrame({"a": [0, 1, 2, 3, 4], "b": [10, 11, 12, 13, 14]}) + + out = lf.with_columns(b=pl.int_range(10, 20, eager=True)).head(5).collect() + assert_frame_equal(expect, out) + + out = lf.select("a", b=pl.int_range(10, 20, eager=True)).head(5).collect() + assert_frame_equal(expect, out) + + assert pl.LazyFrame().select(x=1).head(0).collect().height == 0 + assert pl.LazyFrame().with_columns(x=1).head(0).collect().height == 0 + + q = lf.select(x=1).head(0) + assert q.collect().height == 0 + + # For select, slice pushdown should happen when at least 1 input column is selected + q = lf.select("a", x=1).head(0) + plan = q.explain() + assert plan.index("SELECT") < plan.index("SLICE") + assert q.collect().height == 0 + + # For with_columns, slice pushdown should happen if the input has at least 1 column + q = lf.with_columns(x=1).head(0) + plan = q.explain() + assert plan.index("WITH_COLUMNS") < plan.index("SLICE") + assert q.collect().height == 0 + + q = lf.with_columns(pl.col("a") + 1).head(0) + plan = q.explain() + assert plan.index("WITH_COLUMNS") < plan.index("SLICE") + assert q.collect().height == 0 + + # This does not project any of the original columns + q = lf.with_columns(a=1, b=2).head(0) + plan = q.explain() + assert plan.index("SLICE") < plan.index("WITH_COLUMNS") + assert q.collect().height == 0 + + q = lf.with_columns(b=1, c=2).head(0) + plan = q.explain() + assert plan.index("WITH_COLUMNS") < plan.index("SLICE") + assert q.collect().height == 0 diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 4ea5c3e4059e..a8e60ec192a9 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -789,3 +789,10 @@ def test_sort_with_null_12272() -> None: ) def test_sort_series_nulls_last(input: list[Any], expected: list[Any]) -> None: assert pl.Series(input).sort(nulls_last=True).to_list() == expected + + +def test_sorted_flag_14552() -> None: + a = pl.DataFrame({"a": [2, 1, 3]}) + + a = pl.concat([a, a], rechunk=False) + assert not a.join(a, on="a", how="left")["a"].flags["SORTED_ASC"] diff --git a/py-polars/tests/unit/operations/test_statistics.py b/py-polars/tests/unit/operations/test_statistics.py index 91d5b388f08f..865466107a01 100644 --- a/py-polars/tests/unit/operations/test_statistics.py +++ b/py-polars/tests/unit/operations/test_statistics.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import timedelta from typing import cast @@ -47,6 +49,15 @@ def test_hist() -> None: ).to_series().to_list() == [0, 3, 4] +@pytest.mark.parametrize("values", [[], [None]]) +def test_hist_empty_or_all_null(values: list[None]) -> None: + ser = pl.Series(values, dtype=pl.Float64) + assert ( + str(ser.hist().to_dict(as_series=False)) + == "{'break_point': [inf], 'category': ['(-inf, inf]'], 'count': [0]}" + ) + + @pytest.mark.parametrize("n", [3, 10, 25]) def test_hist_rand(n: int) -> None: a = pl.Series(np.random.randint(0, 100, n)) diff --git a/py-polars/tests/unit/operations/test_transpose.py b/py-polars/tests/unit/operations/test_transpose.py index 133b974b141e..78a6222dfedf 100644 --- a/py-polars/tests/unit/operations/test_transpose.py +++ b/py-polars/tests/unit/operations/test_transpose.py @@ -200,3 +200,9 @@ def test_transpose_name_from_column_13777() -> None: csv_file = io.BytesIO(b"id,kc\nhi,3") df = pl.read_csv(csv_file).transpose(column_names="id") assert_series_equal(df.to_series(0), pl.Series("hi", [3])) + + +def test_transpose_multiple_chunks() -> None: + df = pl.DataFrame({"a": ["1"]}) + expected = pl.DataFrame({"column_0": ["1"], "column_1": ["1"]}) + assert_frame_equal(df.vstack(df).transpose(), expected) diff --git a/py-polars/tests/unit/operations/test_with_columns.py b/py-polars/tests/unit/operations/test_with_columns.py index c01b73edbe48..29dcb3ef0b84 100644 --- a/py-polars/tests/unit/operations/test_with_columns.py +++ b/py-polars/tests/unit/operations/test_with_columns.py @@ -149,3 +149,19 @@ def test_with_columns_single_series() -> None: expected = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) assert_frame_equal(result.collect(), expected) + + +def test_with_columns_seq() -> None: + df = pl.DataFrame({"a": [1, 2]}) + result = df.with_columns_seq( + pl.lit(5).alias("b"), + pl.lit("foo").alias("c"), + ) + expected = pl.DataFrame( + { + "a": [1, 2], + "b": pl.Series([5, 5], dtype=pl.Int32), + "c": ["foo", "foo"], + } + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/unique/test_approx_n_unique.py b/py-polars/tests/unit/operations/unique/test_approx_n_unique.py new file mode 100644 index 000000000000..35b9c1598366 --- /dev/null +++ b/py-polars/tests/unit/operations/unique/test_approx_n_unique.py @@ -0,0 +1,20 @@ +import pytest + +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + + +def test_df_approx_n_unique_deprecated() -> None: + df = pl.DataFrame({"a": [1, 2, 2], "b": [2, 2, 2]}) + with pytest.deprecated_call(): + result = df.approx_n_unique() + expected = pl.DataFrame({"a": [2], "b": [1]}).cast(pl.UInt32) + assert_frame_equal(result, expected) + + +def test_lf_approx_n_unique_deprecated() -> None: + df = pl.LazyFrame({"a": [1, 2, 2], "b": [2, 2, 2]}) + with pytest.deprecated_call(): + result = df.approx_n_unique() + expected = pl.LazyFrame({"a": [2], "b": [1]}).cast(pl.UInt32) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/unique/test_unique.py b/py-polars/tests/unit/operations/unique/test_unique.py index edd92def87a8..bb01912d5cc7 100644 --- a/py-polars/tests/unit/operations/unique/test_unique.py +++ b/py-polars/tests/unit/operations/unique/test_unique.py @@ -33,6 +33,20 @@ def test_unique_predicate_pd() -> None: expected = pl.DataFrame({"x": ["abc"], "y": ["xxx"], "z": [True]}) assert_frame_equal(result, expected) + # Issue #14595: filter should not naively be pushed past unique() + for maintain_order in (True, False): + for keep in ("first", "last", "any", "none"): + q = ( + lf.unique("x", maintain_order=maintain_order, keep=keep) # type: ignore[arg-type] + .filter(pl.col("x") == "abc") + .filter(pl.col("z")) + ) + plan = q.explain() + assert r'FILTER col("z")' in plan + # We can push filters if they only depend on the subset columns of unique() + assert r'SELECTION: "[(col(\"x\")) == (String(abc))]"' in plan + assert_frame_equal(q.collect(predicate_pushdown=False), q.collect()) + def test_unique_on_list_df() -> None: assert pl.DataFrame( @@ -126,3 +140,17 @@ def test_unique_categorical(input: list[str | None], output: list[str | None]) - result = s.unique(maintain_order=True) expected = pl.Series(output, dtype=pl.Categorical) assert_series_equal(result, expected) + + +def test_unique_with_null() -> None: + df = pl.DataFrame( + { + "a": [1, 1, 2, 2, 3, 4], + "b": ["a", "a", "b", "b", "c", "c"], + "c": [None, None, None, None, None, None], + } + ) + expected_df = pl.DataFrame( + {"a": [1, 2, 3, 4], "b": ["a", "b", "c", "c"], "c": [None, None, None, None]} + ) + assert_frame_equal(df.unique(maintain_order=True), expected_df) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 748c23d1beaf..32e60e8dfc5e 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -8,7 +8,6 @@ import pandas as pd import pyarrow as pa import pytest -from numpy.testing import assert_array_equal import polars import polars as pl @@ -49,6 +48,15 @@ def test_cum_agg() -> None: assert_series_equal(s.cum_prod(), pl.Series("a", [1, 2, 6, 12])) +def test_cum_agg_with_nulls() -> None: + # confirm that known series give expected results + s = pl.Series("a", [None, 2, None, 7, 8, None]) + assert_series_equal(s.cum_sum(), pl.Series("a", [None, 2, None, 9, 17, None])) + assert_series_equal(s.cum_min(), pl.Series("a", [None, 2, None, 2, 2, None])) + assert_series_equal(s.cum_max(), pl.Series("a", [None, 2, None, 7, 8, None])) + assert_series_equal(s.cum_prod(), pl.Series("a", [None, 2, None, 14, 112, None])) + + def test_cum_agg_deprecated() -> None: # confirm that known series give expected results s = pl.Series("a", [1, 2, 3, 2]) @@ -265,10 +273,7 @@ def test_concat() -> None: assert s.len() == 3 -@pytest.mark.parametrize( - "dtype", - [pl.Int64, pl.Float64, pl.String, pl.Boolean], -) +@pytest.mark.parametrize("dtype", [pl.Int64, pl.Float64, pl.String, pl.Boolean]) def test_eq_missing_list_and_primitive(dtype: PolarsDataType) -> None: s1 = pl.Series([None, None], dtype=dtype) s2 = pl.Series([None, None], dtype=pl.List(dtype)) @@ -330,7 +335,7 @@ def test_bitwise_ops() -> None: def test_bitwise_floats_invert() -> None: s = pl.Series([2.0, 3.0, 0.0]) - with pytest.raises(pl.SchemaError): + with pytest.raises(pl.InvalidOperationError): ~s @@ -376,6 +381,23 @@ def test_date_agg() -> None: assert series.max() == date(9009, 9, 9) +@pytest.mark.parametrize( + ("s", "min", "max"), + [ + (pl.Series(["c", "b", "a"], dtype=pl.Categorical("lexical")), "a", "c"), + (pl.Series(["a", "c", "b"], dtype=pl.Categorical), "a", "b"), + (pl.Series([None, "a", "c", "b"], dtype=pl.Categorical("lexical")), "a", "c"), + (pl.Series([None, "c", "a", "b"], dtype=pl.Categorical), "c", "b"), + (pl.Series([], dtype=pl.Categorical("lexical")), None, None), + (pl.Series(["c", "b", "a"], dtype=pl.Enum(["c", "b", "a"])), "c", "a"), + (pl.Series(["c", "b", "a"], dtype=pl.Enum(["c", "b", "a", "d"])), "c", "a"), + ], +) +def test_categorical_agg(s: pl.Series, min: str | None, max: str | None) -> None: + assert s.min() == min + assert s.max() == max + + @pytest.mark.parametrize( "s", [pl.Series([1, 2], dtype=Int64), pl.Series([1, 2], dtype=Float64)] ) @@ -785,120 +807,6 @@ def test_arrow() -> None: ) -def test_ufunc() -> None: - # test if output dtype is calculated correctly. - s_float32 = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=pl.Float32) - assert_series_equal( - cast(pl.Series, np.multiply(s_float32, 4)), - pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=pl.Float32), - ) - - s_float64 = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=pl.Float64) - assert_series_equal( - cast(pl.Series, np.multiply(s_float64, 4)), - pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=pl.Float64), - ) - - s_uint8 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8) - assert_series_equal( - cast(pl.Series, np.power(s_uint8, 2)), - pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt8), - ) - assert_series_equal( - cast(pl.Series, np.power(s_uint8, 2.0)), - pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), - ) - assert_series_equal( - cast(pl.Series, np.power(s_uint8, 2, dtype=np.uint16)), - pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt16), - ) - - s_int8 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int8) - assert_series_equal( - cast(pl.Series, np.power(s_int8, 2)), - pl.Series("a", [1, 4, 9, 16], dtype=pl.Int8), - ) - assert_series_equal( - cast(pl.Series, np.power(s_int8, 2.0)), - pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), - ) - assert_series_equal( - cast(pl.Series, np.power(s_int8, 2, dtype=np.int16)), - pl.Series("a", [1, 4, 9, 16], dtype=pl.Int16), - ) - - s_uint32 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt32) - assert_series_equal( - cast(pl.Series, np.power(s_uint32, 2)), - pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt32), - ) - assert_series_equal( - cast(pl.Series, np.power(s_uint32, 2.0)), - pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), - ) - - s_int32 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int32) - assert_series_equal( - cast(pl.Series, np.power(s_int32, 2)), - pl.Series("a", [1, 4, 9, 16], dtype=pl.Int32), - ) - assert_series_equal( - cast(pl.Series, np.power(s_int32, 2.0)), - pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), - ) - - s_uint64 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt64) - assert_series_equal( - cast(pl.Series, np.power(s_uint64, 2)), - pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt64), - ) - assert_series_equal( - cast(pl.Series, np.power(s_uint64, 2.0)), - pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), - ) - - s_int64 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int64) - assert_series_equal( - cast(pl.Series, np.power(s_int64, 2)), - pl.Series("a", [1, 4, 9, 16], dtype=pl.Int64), - ) - assert_series_equal( - cast(pl.Series, np.power(s_int64, 2.0)), - pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), - ) - - # test if null bitmask is preserved - a1 = pl.Series("a", [1.0, None, 3.0]) - b1 = cast(pl.Series, np.exp(a1)) - assert b1.null_count() == 1 - - # test if it works with chunked series. - a2 = pl.Series("a", [1.0, None, 3.0]) - b2 = pl.Series("b", [4.0, 5.0, None]) - a2.append(b2) - assert a2.n_chunks() == 2 - c2 = np.multiply(a2, 3) - assert_series_equal( - cast(pl.Series, c2), - pl.Series("a", [3.0, None, 9.0, 12.0, 15.0, None]), - ) - - # Test if nulls propagate through ufuncs - a3 = pl.Series("a", [None, None, 3, 3]) - b3 = pl.Series("b", [None, 3, None, 3]) - assert_series_equal( - cast(pl.Series, np.maximum(a3, b3)), pl.Series("a", [None, None, None, 3]) - ) - - -def test_numpy_string_array() -> None: - s_str = pl.Series("a", ["aa", "bb", "cc", "dd"], dtype=pl.String) - assert_array_equal( - np.char.capitalize(s_str), - np.array(["Aa", "Bb", "Cc", "Dd"], dtype=" None: a = pl.Series("a", [1, 2, 3]) pos_idxs = pl.Series("idxs", [2, 0, 1, 0], dtype=pl.Int8) @@ -1447,41 +1355,6 @@ def test_bitwise() -> None: a or b # type: ignore[redundant-expr] -def test_to_numpy(monkeypatch: Any) -> None: - for writable in [False, True]: - for flag in [False, True]: - monkeypatch.setattr(pl.series.series, "_PYARROW_AVAILABLE", flag) - - np_array = pl.Series("a", [1, 2, 3], pl.UInt8).to_numpy(writable=writable) - - np.testing.assert_array_equal(np_array, np.array([1, 2, 3], dtype=np.uint8)) - # Test if numpy array is readonly or writable. - assert np_array.flags.writeable == writable - - if writable: - np_array[1] += 10 - np.testing.assert_array_equal( - np_array, np.array([1, 12, 3], dtype=np.uint8) - ) - - np_array_with_missing_values = pl.Series( - "a", [None, 2, 3], pl.UInt8 - ).to_numpy(writable=writable) - - np.testing.assert_array_equal( - np_array_with_missing_values, - np.array( - [np.nan, 2.0, 3.0], - dtype=(np.float64 if flag is True else np.float32), - ), - ) - - if writable: - # As Null values can't be encoded natively in a numpy array, - # this array will never be a view. - assert np_array_with_missing_values.flags.writeable == writable - - def test_from_generator_or_iterable() -> None: # generator function def gen(n: int) -> Iterator[int]: @@ -1705,76 +1578,62 @@ def test_arg_sort() -> None: assert_series_equal(s.arg_sort(descending=True), expected_descending) -def test_arg_min_and_arg_max() -> None: - # numerical no null. - s = pl.Series([5, 3, 4, 1, 2]) - assert s.arg_min() == 3 - assert s.arg_max() == 0 - - # numerical has null. - s = pl.Series([None, 5, 1]) - assert s.arg_min() == 2 - assert s.arg_max() == 1 - - # numerical all null. - s = pl.Series([None, None], dtype=Int32) - assert s.arg_min() is None - assert s.arg_max() is None - - # boolean no null. - s = pl.Series([True, False]) - assert s.arg_min() == 1 - assert s.arg_max() == 0 - s = pl.Series([True, True]) - assert s.arg_min() == 0 - assert s.arg_max() == 0 - s = pl.Series([False, False]) - assert s.arg_min() == 0 - assert s.arg_max() == 0 - - # boolean has null. - s = pl.Series([None, True, False, True]) - assert s.arg_min() == 2 - assert s.arg_max() == 1 - s = pl.Series([None, True, True]) - assert s.arg_min() == 1 - assert s.arg_max() == 1 - s = pl.Series([None, False, False]) - assert s.arg_min() == 1 - assert s.arg_max() == 1 - - # boolean all null. - s = pl.Series([None, None], dtype=pl.Boolean) - assert s.arg_min() is None - assert s.arg_max() is None - - # str no null - s = pl.Series(["a", "c", "b"]) - assert s.arg_min() == 0 - assert s.arg_max() == 1 +@pytest.mark.parametrize( + ("series", "argmin", "argmax"), + [ + # Numeric + (pl.Series([5, 3, 4, 1, 2]), 3, 0), + (pl.Series([None, 5, 1]), 2, 1), + # Boolean + (pl.Series([True, False]), 1, 0), + (pl.Series([True, True]), 0, 0), + (pl.Series([False, False]), 0, 0), + (pl.Series([None, True, False, True]), 2, 1), + (pl.Series([None, True, True]), 1, 1), + (pl.Series([None, False, False]), 1, 1), + # String + (pl.Series(["a", "c", "b"]), 0, 1), + (pl.Series([None, "a", None, "b"]), 1, 3), + # Categorical + (pl.Series(["c", "b", "a"], dtype=pl.Categorical), 0, 2), + (pl.Series([None, "c", "b", None, "a"], dtype=pl.Categorical), 1, 4), + (pl.Series(["c", "b", "a"], dtype=pl.Categorical(ordering="lexical")), 2, 0), + ( + pl.Series( + [None, "c", "b", None, "a"], dtype=pl.Categorical(ordering="lexical") + ), + 4, + 1, + ), + ], +) +def test_arg_min_arg_max(series: pl.Series, argmin: int, argmax: int) -> None: + assert series.arg_min() == argmin + assert series.arg_max() == argmax - # str has null - s = pl.Series([None, "a", None, "b"]) - assert s.arg_min() == 1 - assert s.arg_max() == 3 - # str all null - s = pl.Series([None, None], dtype=pl.String) - assert s.arg_min() is None - assert s.arg_max() is None +@pytest.mark.parametrize( + ("series"), + [ + # All nulls + pl.Series([None, None], dtype=pl.Int32), + pl.Series([None, None], dtype=pl.Boolean), + pl.Series([None, None], dtype=pl.String), + pl.Series([None, None], dtype=pl.Categorical), + pl.Series([None, None], dtype=pl.Categorical(ordering="lexical")), + # Empty Series + pl.Series([], dtype=pl.Int32), + pl.Series([], dtype=pl.Boolean), + pl.Series([], dtype=pl.String), + pl.Series([], dtype=pl.Categorical), + ], +) +def test_arg_min_arg_max_all_nulls_or_empty(series: pl.Series) -> None: + assert series.arg_min() is None + assert series.arg_max() is None - # test ascending and descending series - s = pl.Series([None, 1, 2, 3, 4, 5]) - s.sort(in_place=True) # set ascending sorted flag - assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False} - assert s.arg_min() == 1 - assert s.arg_max() == 5 - s = pl.Series([None, 5, 4, 3, 2, 1]) - s.sort(descending=True, in_place=True) # set descing sorted flag - assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True} - assert s.arg_min() == 5 - assert s.arg_max() == 1 +def test_arg_min_and_arg_max_sorted() -> None: # test ascending and descending numerical series s = pl.Series([None, 1, 2, 3, 4, 5]) s.sort(in_place=True) # set ascending sorted flag @@ -1799,56 +1658,6 @@ def test_arg_min_and_arg_max() -> None: assert s.arg_min() == 5 assert s.arg_max() == 1 - # test numerical empty series - s = pl.Series([], dtype=pl.Int32) - assert s.arg_min() is None - assert s.arg_max() is None - - # test boolean empty series - s = pl.Series([], dtype=pl.Boolean) - assert s.arg_min() is None - assert s.arg_max() is None - - # test str empty series - s = pl.Series([], dtype=pl.String) - assert s.arg_min() is None - assert s.arg_max() is None - - # categorical empty series - s = pl.Series([], dtype=pl.Categorical) - assert s.arg_min() is None - assert s.arg_max() is None - - # categorical with physical ordering no null - s = pl.Series(["c", "b", "a"], dtype=pl.Categorical) - assert s.arg_min() == 0 - assert s.arg_max() == 2 - - # categorical with physical ordering has null - s = pl.Series([None, "c", "b", None, "a"], dtype=pl.Categorical) - assert s.arg_min() == 1 - assert s.arg_max() == 4 - - # categorical with physical all null - s = pl.Series([None, None], dtype=pl.Categorical) - assert s.arg_min() is None - assert s.arg_max() is None - - # categorical with lexical ordering no null - s = pl.Series(["c", "b", "a"], dtype=pl.Categorical(ordering="lexical")) - assert s.arg_min() == 2 - assert s.arg_max() == 0 - - # categorical with lexical ordering has null - s = pl.Series([None, "c", "b", None, "a"], dtype=pl.Categorical(ordering="lexical")) - assert s.arg_min() == 4 - assert s.arg_max() == 1 - - # categorical with lexical ordering all null - s = pl.Series([None, None], dtype=pl.Categorical(ordering="lexical")) - assert s.arg_min() is None - assert s.arg_max() is None - def test_is_null_is_not_null() -> None: s = pl.Series("a", [1.0, 2.0, 3.0, None]) @@ -2100,196 +1909,6 @@ def test_trigonometric_invalid_input() -> None: s.cosh() -def test_ewm_mean() -> None: - s = pl.Series([2, 5, 3]) - - expected = pl.Series([2.0, 4.0, 3.4285714285714284]) - assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected - ) - - expected = pl.Series([2.0, 3.8, 3.421053]) - assert_series_equal(s.ewm_mean(com=2.0, adjust=True, ignore_nulls=True), expected) - assert_series_equal(s.ewm_mean(com=2.0, adjust=True, ignore_nulls=False), expected) - - expected = pl.Series([2.0, 3.5, 3.25]) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected - ) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected - ) - - s = pl.Series([2, 3, 5, 7, 4]) - - expected = pl.Series([None, 2.666667, 4.0, 5.6, 4.774194]) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, min_periods=2, ignore_nulls=True), expected - ) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, min_periods=2, ignore_nulls=False), expected - ) - - expected = pl.Series([None, None, 4.0, 5.6, 4.774194]) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, min_periods=3, ignore_nulls=True), expected - ) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, min_periods=3, ignore_nulls=False), expected - ) - - s = pl.Series([None, 1.0, 5.0, 7.0, None, 2.0, 5.0, 4]) - - expected = pl.Series( - [ - None, - 1.0, - 3.6666666666666665, - 5.571428571428571, - 5.571428571428571, - 3.6666666666666665, - 4.354838709677419, - 4.174603174603175, - ], - ) - assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected) - expected = pl.Series( - [ - None, - 1.0, - 3.666666666666667, - 5.571428571428571, - 5.571428571428571, - 3.08695652173913, - 4.2, - 4.092436974789916, - ] - ) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected - ) - - expected = pl.Series([None, 1.0, 3.0, 5.0, 5.0, 3.5, 4.25, 4.125]) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected - ) - - expected = pl.Series([None, 1.0, 3.0, 5.0, 5.0, 3.0, 4.0, 4.0]) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected - ) - - -def test_ewm_mean_leading_nulls() -> None: - for min_periods in [1, 2, 3]: - assert ( - pl.Series([1, 2, 3, 4]) - .ewm_mean(com=3, min_periods=min_periods) - .null_count() - == min_periods - 1 - ) - assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( - alpha=0.5, min_periods=1 - ).to_list() == [None, 1.0, 1.0, 1.0] - assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( - alpha=0.5, min_periods=2 - ).to_list() == [None, None, 1.0, 1.0] - - -def test_ewm_mean_min_periods() -> None: - series = pl.Series([1.0, None, None, None]) - - ewm_mean = series.ewm_mean(alpha=0.5, min_periods=1) - assert ewm_mean.to_list() == [1.0, 1.0, 1.0, 1.0] - ewm_mean = series.ewm_mean(alpha=0.5, min_periods=2) - assert ewm_mean.to_list() == [None, None, None, None] - - series = pl.Series([1.0, None, 2.0, None, 3.0]) - - ewm_mean = series.ewm_mean(alpha=0.5, min_periods=1) - assert_series_equal( - ewm_mean, - pl.Series( - [ - 1.0, - 1.0, - 1.6666666666666665, - 1.6666666666666665, - 2.4285714285714284, - ] - ), - ) - ewm_mean = series.ewm_mean(alpha=0.5, min_periods=2) - assert_series_equal( - ewm_mean, - pl.Series( - [ - None, - None, - 1.6666666666666665, - 1.6666666666666665, - 2.4285714285714284, - ] - ), - ) - - -def test_ewm_std_var() -> None: - series = pl.Series("a", [2, 5, 3]) - - var = series.ewm_var(alpha=0.5) - std = series.ewm_std(alpha=0.5) - - assert np.allclose(var, std**2, rtol=1e-16) - - -def test_ewm_param_validation() -> None: - s = pl.Series("values", range(10)) - - with pytest.raises(ValueError, match="mutually exclusive"): - s.ewm_std(com=0.5, alpha=0.5) - - with pytest.raises(ValueError, match="mutually exclusive"): - s.ewm_mean(span=1.5, half_life=0.75) - - with pytest.raises(ValueError, match="mutually exclusive"): - s.ewm_var(alpha=0.5, span=1.5) - - with pytest.raises(ValueError, match="require `com` >= 0"): - s.ewm_std(com=-0.5) - - with pytest.raises(ValueError, match="require `span` >= 1"): - s.ewm_mean(span=0.5) - - with pytest.raises(ValueError, match="require `half_life` > 0"): - s.ewm_var(half_life=0) - - for alpha in (-0.5, -0.0000001, 0.0, 1.0000001, 1.5): - with pytest.raises(ValueError, match="require 0 < `alpha` <= 1"): - s.ewm_std(alpha=alpha) - - -@pytest.mark.parametrize( - ("const", "dtype"), - [ - (1, pl.Int8), - (4, pl.UInt32), - (4.5, pl.Float32), - (None, pl.Float64), - ("白鵬翔", pl.String), - (date.today(), pl.Date), - (datetime.now(), pl.Datetime("ns")), - (time(23, 59, 59), pl.Time), - (timedelta(hours=7, seconds=123), pl.Duration("ms")), - ], -) -def test_extend_constant(const: Any, dtype: pl.PolarsDataType) -> None: - s = pl.Series("s", [None], dtype=dtype) - expected = pl.Series("s", [None, const, const, const], dtype=dtype) - assert_series_equal(s.extend_constant(const, 3), expected) - - def test_product() -> None: a = pl.Series("a", [1, 2, 3]) out = a.product() diff --git a/py-polars/tests/unit/series/test_to_numpy.py b/py-polars/tests/unit/series/test_to_numpy.py deleted file mode 100644 index e245009e7171..000000000000 --- a/py-polars/tests/unit/series/test_to_numpy.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import annotations - -import numpy as np -from hypothesis import given, settings -from numpy.testing import assert_array_equal - -import polars as pl -from polars.testing.parametric import series - - -@given( - s=series( - min_size=1, max_size=10, excluded_dtypes=[pl.Categorical, pl.List, pl.Struct] - ).filter( - lambda s: ( - getattr(s.dtype, "time_unit", None) != "ms" - and not (s.dtype == pl.String and s.str.contains("\x00").any()) - and not (s.dtype == pl.Binary and s.bin.contains(b"\x00").any()) - ) - ), -) -@settings(max_examples=250) -def test_series_to_numpy(s: pl.Series) -> None: - result = s.to_numpy() - - values = s.to_list() - dtype_map = { - pl.Datetime("ns"): "datetime64[ns]", - pl.Datetime("us"): "datetime64[us]", - pl.Duration("ns"): "timedelta64[ns]", - pl.Duration("us"): "timedelta64[us]", - } - np_dtype = dtype_map.get(s.dtype) # type: ignore[call-overload] - expected = np.array(values, dtype=np_dtype) - - assert_array_equal(result, expected) diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py index da8cdeff784b..9659c720ce84 100644 --- a/py-polars/tests/unit/sql/test_temporal.py +++ b/py-polars/tests/unit/sql/test_temporal.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import date, datetime, time -from typing import Any +from typing import Any, Literal import pytest @@ -32,6 +32,30 @@ def test_date() -> None: assert_frame_equal(result, expected) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_datetime_to_time(time_unit: Literal["ns", "us", "ms"]) -> None: + df = pl.DataFrame( + { + "dtm": [ + datetime(2099, 12, 31, 23, 59, 59), + datetime(1999, 12, 31, 12, 30, 30), + datetime(1969, 12, 31, 1, 1, 1), + datetime(1899, 12, 31, 0, 0, 0), + ], + }, + schema={"dtm": pl.Datetime(time_unit)}, + ) + with pl.SQLContext(df=df, eager_execution=True) as ctx: + result = ctx.execute("SELECT dtm::time as tm from df")["tm"].to_list() + + assert result == [ + time(23, 59, 59), + time(12, 30, 30), + time(1, 1, 1), + time(0, 0, 0), + ] + + @pytest.mark.parametrize( ("part", "dtype", "expected"), [ @@ -130,7 +154,6 @@ def test_extract_century_millennium(dt: date, expected: list[int]) -> None: ("ms", [1704589323123, 1609324245987, 1136159999555]), ("us", [1704589323123456, 1609324245987654, 1136159999555555]), ("ns", [1704589323123456000, 1609324245987654000, 1136159999555555000]), - (None, [1704589323123456, 1609324245987654, 1136159999555555]), ], ) def test_timestamp_time_unit(unit: str | None, expected: list[int]) -> None: diff --git a/py-polars/tests/unit/streaming/test_streaming.py b/py-polars/tests/unit/streaming/test_streaming.py index fd18289fdc86..165183673471 100644 --- a/py-polars/tests/unit/streaming/test_streaming.py +++ b/py-polars/tests/unit/streaming/test_streaming.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import tempfile import time from datetime import date from pathlib import Path @@ -332,19 +331,32 @@ def test_streaming_11219() -> None: @pytest.mark.write_disk() -def test_custom_temp_dir(monkeypatch: Any) -> None: - test_temp_dir = "test_temp_dir" - temp_dir = Path(tempfile.gettempdir()) / test_temp_dir +def test_streaming_csv_headers_but_no_data_13770(tmp_path: Path) -> None: + with Path.open(tmp_path / "header_no_data.csv", "w") as f: + f.write("name, age\n") + + schema = {"name": pl.String, "age": pl.Int32} + df = ( + pl.scan_csv(tmp_path / "header_no_data.csv", schema=schema) + .head() + .collect(streaming=True) + ) + assert len(df) == 0 + assert df.schema == schema - monkeypatch.setenv("POLARS_VERBOSE", "1") + +@pytest.mark.write_disk() +def test_custom_temp_dir(tmp_path: Path, monkeypatch: Any) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") - monkeypatch.setenv("POLARS_TEMP_DIR", str(temp_dir)) + monkeypatch.setenv("POLARS_VERBOSE", "1") s = pl.arange(0, 100_000, eager=True).rename("idx") df = s.shuffle().to_frame() df.lazy().sort("idx").collect(streaming=True) - assert os.listdir(temp_dir), f"Temp directory '{temp_dir}' is empty" + assert os.listdir(tmp_path), f"Temp directory '{tmp_path}' is empty" @pytest.mark.write_disk() diff --git a/py-polars/tests/unit/streaming/test_streaming_group_by.py b/py-polars/tests/unit/streaming/test_streaming_group_by.py index 4a40c07c8e7a..7c13b7a05804 100644 --- a/py-polars/tests/unit/streaming/test_streaming_group_by.py +++ b/py-polars/tests/unit/streaming/test_streaming_group_by.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import date -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import pytest @@ -9,6 +9,9 @@ import polars as pl from polars.testing import assert_frame_equal +if TYPE_CHECKING: + from pathlib import Path + pytestmark = pytest.mark.xdist_group("streaming") @@ -202,15 +205,17 @@ def random_integers() -> pl.Series: @pytest.mark.write_disk() def test_streaming_group_by_ooc_q1( - monkeypatch: Any, random_integers: pl.Series + random_integers: pl.Series, + tmp_path: Path, + monkeypatch: Any, ) -> None: - s = random_integers + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") + lf = random_integers.to_frame().lazy() result = ( - s.to_frame() - .lazy() - .group_by("a") + lf.group_by("a") .agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last")) .sort("a") .collect(streaming=True) @@ -228,16 +233,17 @@ def test_streaming_group_by_ooc_q1( @pytest.mark.write_disk() def test_streaming_group_by_ooc_q2( - monkeypatch: Any, random_integers: pl.Series + random_integers: pl.Series, + tmp_path: Path, + monkeypatch: Any, ) -> None: - s = random_integers + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") + lf = random_integers.cast(str).to_frame().lazy() result = ( - s.cast(str) - .to_frame() - .lazy() - .group_by("a") + lf.group_by("a") .agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last")) .sort("a") .collect(streaming=True) @@ -255,15 +261,17 @@ def test_streaming_group_by_ooc_q2( @pytest.mark.write_disk() def test_streaming_group_by_ooc_q3( - monkeypatch: Any, random_integers: pl.Series + random_integers: pl.Series, + tmp_path: Path, + monkeypatch: Any, ) -> None: - s = random_integers + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") + lf = pl.LazyFrame({"a": random_integers, "b": random_integers}) result = ( - pl.DataFrame({"a": s, "b": s}) - .lazy() - .group_by(["a", "b"]) + lf.group_by("a", "b") .agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last")) .sort("a") .collect(streaming=True) diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py index 47a3a09426e3..d405fec1183c 100644 --- a/py-polars/tests/unit/streaming/test_streaming_io.py +++ b/py-polars/tests/unit/streaming/test_streaming_io.py @@ -112,6 +112,14 @@ def test_sink_csv(io_files_path: Path, tmp_path: Path) -> None: assert_frame_equal(target_data, source_data) +@pytest.mark.write_disk() +def test_sink_csv_14494(tmp_path: Path) -> None: + pl.LazyFrame({"c": [1, 2, 3]}, schema={"c": pl.Int64}).filter( + pl.col("c") > 10 + ).sink_csv(tmp_path / "sink.csv") + assert pl.read_csv(tmp_path / "sink.csv").columns == ["c"] + + def test_sink_csv_with_options() -> None: """ Test with all possible options. diff --git a/py-polars/tests/unit/streaming/test_streaming_join.py b/py-polars/tests/unit/streaming/test_streaming_join.py index b80847803741..690c60ce4e79 100644 --- a/py-polars/tests/unit/streaming/test_streaming_join.py +++ b/py-polars/tests/unit/streaming/test_streaming_join.py @@ -190,3 +190,18 @@ def test_join_null_matches_multiple_keys(streaming: bool) -> None: assert_frame_equal( df_a.join(df_b, on=["a", "idx"], how="outer").sort("a").collect(), expected ) + + +def test_streaming_join_and_union() -> None: + a = pl.LazyFrame({"a": [1, 2]}) + + b = pl.LazyFrame({"a": [1, 2, 4, 8]}) + + c = a.join(b, on="a") + # The join node latest ensures that the dispatcher + # needs to replace placeholders in unions. + q = pl.concat([a, b, c]) + + out = q.collect(streaming=True) + assert_frame_equal(out, q.collect(streaming=False)) + assert out.to_series().to_list() == [1, 2, 1, 2, 4, 8, 1, 2] diff --git a/py-polars/tests/unit/streaming/test_streaming_sort.py b/py-polars/tests/unit/streaming/test_streaming_sort.py index c9befd8362c3..3038d9dcbe14 100644 --- a/py-polars/tests/unit/streaming/test_streaming_sort.py +++ b/py-polars/tests/unit/streaming/test_streaming_sort.py @@ -75,7 +75,9 @@ def test_streaming_sort_multiple_columns_logical_types() -> None: @pytest.mark.write_disk() @pytest.mark.slow() -def test_ooc_sort(monkeypatch: Any) -> None: +def test_ooc_sort(tmp_path: Path, monkeypatch: Any) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") s = pl.arange(0, 100_000, eager=True).rename("idx") @@ -91,9 +93,11 @@ def test_ooc_sort(monkeypatch: Any) -> None: @pytest.mark.write_disk() -def test_streaming_sort(monkeypatch: Any, capfd: Any) -> None: - monkeypatch.setenv("POLARS_VERBOSE", "1") +def test_streaming_sort(tmp_path: Path, monkeypatch: Any, capfd: Any) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") + monkeypatch.setenv("POLARS_VERBOSE", "1") # this creates a lot of duplicate partitions and triggers: #7568 assert ( pl.Series(np.random.randint(0, 100, 100)) @@ -108,11 +112,13 @@ def test_streaming_sort(monkeypatch: Any, capfd: Any) -> None: @pytest.mark.write_disk() -def test_out_of_core_sort_9503(monkeypatch: Any) -> None: +def test_out_of_core_sort_9503(tmp_path: Path, monkeypatch: Any) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") np.random.seed(0) - num_rows = 1_00_000 + num_rows = 100_000 num_columns = 2 num_tables = 10 @@ -162,15 +168,13 @@ def test_out_of_core_sort_9503(monkeypatch: Any) -> None: } -@pytest.mark.skip( - reason="This test is unreliable - it fails intermittently in our CI" - " with 'OSError: No such file or directory (os error 2)'." -) @pytest.mark.write_disk() @pytest.mark.slow() def test_streaming_sort_multiple_columns( - str_ints_df: pl.DataFrame, monkeypatch: Any, capfd: Any + str_ints_df: pl.DataFrame, tmp_path: Path, monkeypatch: Any, capfd: Any ) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") monkeypatch.setenv("POLARS_VERBOSE", "1") df = str_ints_df diff --git a/py-polars/tests/unit/streaming/test_streaming_unique.py b/py-polars/tests/unit/streaming/test_streaming_unique.py index c79a734464a3..77d7534548dd 100644 --- a/py-polars/tests/unit/streaming/test_streaming_unique.py +++ b/py-polars/tests/unit/streaming/test_streaming_unique.py @@ -16,8 +16,10 @@ @pytest.mark.write_disk() @pytest.mark.slow() def test_streaming_out_of_core_unique( - io_files_path: Path, monkeypatch: Any, capfd: Any + io_files_path: Path, tmp_path: Path, monkeypatch: Any, capfd: Any ) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") monkeypatch.setenv("POLARS_VERBOSE", "1") monkeypatch.setenv("POLARS_STREAMING_GROUPBY_SPILL_SIZE", "256") diff --git a/py-polars/tests/unit/test_arity.py b/py-polars/tests/unit/test_arity.py index 4be2a1910fe9..ea62e6583cae 100644 --- a/py-polars/tests/unit/test_arity.py +++ b/py-polars/tests/unit/test_arity.py @@ -79,3 +79,25 @@ def test_broadcast_string_ops_12632( assert df.select(needs_broadcast.str.strip_chars(pl.col("name"))).height == 3 assert df.select(needs_broadcast.str.strip_chars_start(pl.col("name"))).height == 3 assert df.select(needs_broadcast.str.strip_chars_end(pl.col("name"))).height == 3 + + +def test_negate_inlined_14278() -> None: + df = pl.DataFrame( + {"group": ["A", "A", "B", "B", "B", "C", "C"], "value": [1, 2, 3, 4, 5, 6, 7]} + ) + + agg_expr = [ + pl.struct("group", "value").tail(2).alias("list"), + pl.col("value").sort().tail(2).count().alias("count"), + ] + + q = df.lazy().group_by("group").agg(agg_expr) + assert q.collect().sort("group").to_dict(as_series=False) == { + "group": ["A", "B", "C"], + "list": [ + [{"group": "A", "value": 1}, {"group": "A", "value": 2}], + [{"group": "B", "value": 4}, {"group": "B", "value": 5}], + [{"group": "C", "value": 6}, {"group": "C", "value": 7}], + ], + "count": [2, 2, 2], + } diff --git a/py-polars/tests/unit/test_config.py b/py-polars/tests/unit/test_config.py index 17b58c7201c9..f4ed561a3b4e 100644 --- a/py-polars/tests/unit/test_config.py +++ b/py-polars/tests/unit/test_config.py @@ -87,30 +87,6 @@ def test_hide_header_elements() -> None: ) -def test_html_tables() -> None: - df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) - - # default: header contains names/dtypes - header = "abci64i64i64" - assert header in df._repr_html_() - - # validate that relevant config options are respected - with pl.Config(tbl_hide_column_names=True): - header = "i64i64i64" - assert header in df._repr_html_() - - with pl.Config(tbl_hide_column_data_types=True): - header = "abc" - assert header in df._repr_html_() - - with pl.Config( - tbl_hide_column_data_types=True, - tbl_hide_column_names=True, - ): - header = "" - assert header in df._repr_html_() - - def test_set_tbl_cols() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) @@ -196,7 +172,7 @@ def test_set_tbl_rows() -> None: "╞═════╪═════╪═════╡\n" "│ 1 ┆ 5 ┆ 9 │\n" "│ 2 ┆ 6 ┆ 10 │\n" - "│ 3 ┆ 7 ┆ 11 │\n" + "│ … ┆ … ┆ … │\n" "│ 4 ┆ 8 ┆ 12 │\n" "└─────┴─────┴─────┘" ) @@ -205,8 +181,8 @@ def test_set_tbl_rows() -> None: "Series: 'ser' [i64]\n" "[\n" "\t1\n" + "\t2\n" "\t…\n" - "\t4\n" "\t5\n" "]" ) @@ -231,7 +207,7 @@ def test_set_tbl_rows() -> None: "[\n" "\t1\n" "\t2\n" - "\t3\n" + "\t…\n" "\t4\n" "\t5\n" "]" @@ -254,8 +230,8 @@ def test_set_tbl_rows() -> None: "│ i64 ┆ i64 ┆ i64 │\n" "╞═════╪═════╪═════╡\n" "│ 1 ┆ 6 ┆ 11 │\n" + "│ 2 ┆ 7 ┆ 12 │\n" "│ … ┆ … ┆ … │\n" - "│ 4 ┆ 9 ┆ 14 │\n" "│ 5 ┆ 10 ┆ 15 │\n" "└─────┴─────┴─────┘" ) diff --git a/py-polars/tests/unit/test_consortium_standard.py b/py-polars/tests/unit/test_consortium_standard.py deleted file mode 100644 index 0c3c5bf6b013..000000000000 --- a/py-polars/tests/unit/test_consortium_standard.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -Test some basic methods of the dataframe consortium standard. - -Full testing is done at https://github.com/data-apis/dataframe-api-compat, -this is just to check that the entry point works as expected. -""" - -import polars as pl - - -def test_dataframe() -> None: - df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - df = df_pl.__dataframe_consortium_standard__() - result = df.get_column_names() - expected = ["a", "b"] - assert result == expected - - -def test_lazyframe() -> None: - df_pl = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - df = df_pl.__dataframe_consortium_standard__() - result = df.get_column_names() - expected = ["a", "b"] - assert result == expected - - -def test_series() -> None: - ser = pl.Series("a", [1, 2, 3]) - col = ser.__column_consortium_standard__() - assert col.name == "a" diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index 3ea9ea029fe8..b12b5f3fab18 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -1,5 +1,6 @@ import re -from datetime import date, datetime +import typing +from datetime import date, datetime, timedelta from tempfile import NamedTemporaryFile from typing import Any @@ -594,3 +595,52 @@ def test_cse_11958() -> None: "diff3": [None, None, None, 30, 30], "diff4": [None, None, None, None, 40], } + + +@typing.no_type_check +def test_cse_14047() -> None: + ldf = pl.LazyFrame( + { + "timestamp": pl.datetime_range( + datetime(2024, 1, 12), + datetime(2024, 1, 12, 0, 0, 0, 150_000), + "10ms", + eager=True, + closed="left", + ), + "price": list(range(15)), + } + ) + + def count_diff( + price: pl.Expr, upper_bound: float = 0.1, lower_bound: float = 0.001 + ): + span_end_to_curr = ( + price.count() + .cast(int) + .rolling("timestamp", period=timedelta(seconds=lower_bound)) + ) + span_start_to_curr = ( + price.count() + .cast(int) + .rolling("timestamp", period=timedelta(seconds=upper_bound)) + ) + return (span_start_to_curr - span_end_to_curr).alias( + f"count_diff_{upper_bound}_{lower_bound}" + ) + + def s_per_count(count_diff, span) -> pl.Expr: + return (span[1] * 1000 - span[0] * 1000) / count_diff + + spans = [(0.001, 0.1), (1, 10)] + count_diff_exprs = [count_diff(pl.col("price"), span[0], span[1]) for span in spans] + s_per_count_exprs = [ + s_per_count(count_diff, span).alias(f"zz_{span}") + for count_diff, span in zip(count_diff_exprs, spans) + ] + + exprs = count_diff_exprs + s_per_count_exprs + ldf = ldf.with_columns(*exprs) + assert_frame_equal( + ldf.collect(comm_subexpr_elim=True), ldf.collect(comm_subexpr_elim=False) + ) diff --git a/py-polars/tests/unit/test_datatypes.py b/py-polars/tests/unit/test_datatypes.py index 59bf304a2bb1..1c26a1aaaea9 100644 --- a/py-polars/tests/unit/test_datatypes.py +++ b/py-polars/tests/unit/test_datatypes.py @@ -100,10 +100,6 @@ def test_dtype_groups() -> None: assert pl.Datetime("ms", "Asia/Tokyo") in grp -def test_get_index_type() -> None: - assert pl.get_index_type() == pl.UInt32 - - def test_dtypes_picklable() -> None: parametric_type = pl.Datetime("ns") singleton_type = pl.Float64 diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index f6b5b2df8110..06e8a906617d 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -464,7 +464,7 @@ def test_with_column_duplicates() -> None: df = pl.DataFrame({"a": [0, None, 2, 3, None], "b": [None, 1, 2, 3, None]}) with pytest.raises( pl.ComputeError, - match=r"The name: 'same' passed to `LazyFrame.with_columns` is duplicate", + match=r"the name: 'same' passed to `LazyFrame.with_columns` is duplicate.*", ): assert df.with_columns([pl.all().alias("same")]).columns == ["a", "b", "same"] @@ -689,3 +689,13 @@ def test_error_list_to_array() -> None: pl.DataFrame( data={"a": [[1, 2], [3, 4, 5]]}, schema={"a": pl.List(pl.Int8)} ).with_columns(array=pl.col("a").list.to_array(2)) + + +# https://github.com/pola-rs/polars/issues/8079 +def test_error_lazyframe_not_repeating() -> None: + lf = pl.LazyFrame({"a": 1, "b": range(2)}) + with pytest.raises(pl.ColumnNotFoundError) as exc_info: + lf.select("c").select("d").select("e").collect() + + match = "Error originated just after this operation:" + assert str(exc_info).count(match) == 1 diff --git a/py-polars/tests/unit/test_format.py b/py-polars/tests/unit/test_format.py index d37b2dbb5d6d..c403e2af7de4 100644 --- a/py-polars/tests/unit/test_format.py +++ b/py-polars/tests/unit/test_format.py @@ -1,13 +1,16 @@ from __future__ import annotations from decimal import Decimal as D -from typing import Any, Iterator +from typing import TYPE_CHECKING, Any, Iterator import pytest import polars as pl from polars import ComputeError +if TYPE_CHECKING: + from polars.type_aliases import PolarsDataType + @pytest.fixture(autouse=True) def _environ() -> Iterator[None]: @@ -58,22 +61,7 @@ def _environ() -> Iterator[None]: 2 3 4 - 5 - 6 - 7 - 8 - 9 - 10 - 11 … - 87 - 88 - 89 - 90 - 91 - 92 - 93 - 94 95 96 97 @@ -95,6 +83,128 @@ def test_fmt_series( assert out == expected +@pytest.mark.parametrize( + ("values", "dtype", "expected"), + [ + ( + [-127, -1, 0, 1, 127], + pl.Int8, + """shape: (5,) +Series: 'foo' [i8] +[ + -127 + -1 + 0 + 1 + 127 +]""", + ), + ( + [-32768, -1, 0, 1, 32767], + pl.Int16, + """shape: (5,) +Series: 'foo' [i16] +[ + -32,768 + -1 + 0 + 1 + 32,767 +]""", + ), + ( + [-2147483648, -1, 0, 1, 2147483647], + pl.Int32, + """shape: (5,) +Series: 'foo' [i32] +[ + -2,147,483,648 + -1 + 0 + 1 + 2,147,483,647 +]""", + ), + ( + [-9223372036854775808, -1, 0, 1, 9223372036854775807], + pl.Int64, + """shape: (5,) +Series: 'foo' [i64] +[ + -9,223,372,036,854,775,808 + -1 + 0 + 1 + 9,223,372,036,854,775,807 +]""", + ), + ], +) +def test_fmt_signed_int_thousands_sep( + values: list[int], dtype: PolarsDataType, expected: str +) -> None: + s = pl.Series(name="foo", values=values, dtype=dtype) + with pl.Config(thousands_separator=True): + assert str(s) == expected + + +@pytest.mark.parametrize( + ("values", "dtype", "expected"), + [ + ( + [0, 1, 127], + pl.UInt8, + """shape: (3,) +Series: 'foo' [u8] +[ + 0 + 1 + 127 +]""", + ), + ( + [0, 1, 32767], + pl.UInt16, + """shape: (3,) +Series: 'foo' [u16] +[ + 0 + 1 + 32,767 +]""", + ), + ( + [0, 1, 2147483647], + pl.UInt32, + """shape: (3,) +Series: 'foo' [u32] +[ + 0 + 1 + 2,147,483,647 +]""", + ), + ( + [0, 1, 9223372036854775807], + pl.UInt64, + """shape: (3,) +Series: 'foo' [u64] +[ + 0 + 1 + 9,223,372,036,854,775,807 +]""", + ), + ], +) +def test_fmt_unsigned_int_thousands_sep( + values: list[int], dtype: PolarsDataType, expected: str +) -> None: + s = pl.Series(name="foo", values=values, dtype=dtype) + with pl.Config(thousands_separator=True): + assert str(s) == expected + + def test_fmt_float(capfd: pytest.CaptureFixture[str]) -> None: s = pl.Series(name="foo", values=[7.966e-05, 7.9e-05, 8.4666e-05, 8.00007966]) print(s) diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 043cd0144fed..dcc387d060fd 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -853,45 +853,6 @@ def test_float_floor_divide() -> None: assert ldf_res == x // step -def test_lazy_ufunc() -> None: - ldf = pl.LazyFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)]) - out = ldf.select( - [ - np.power(cast(Any, pl.col("a")), 2).alias("power_uint8"), - np.power(cast(Any, pl.col("a")), 2.0).alias("power_float64"), - np.power(cast(Any, pl.col("a")), 2, dtype=np.uint16).alias("power_uint16"), - ] - ) - expected = pl.DataFrame( - [ - pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8), - pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), - pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16), - ] - ) - assert_frame_equal(out.collect(), expected) - - -def test_lazy_ufunc_expr_not_first() -> None: - """Check numpy ufunc expressions also work if expression not the first argument.""" - ldf = pl.LazyFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)]) - out = ldf.select( - [ - np.power(2.0, cast(Any, pl.col("a"))).alias("power"), - (2.0 / cast(Any, pl.col("a"))).alias("divide_scalar"), - (np.array([2, 2, 2]) / cast(Any, pl.col("a"))).alias("divide_array"), - ] - ) - expected = pl.DataFrame( - [ - pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64), - pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), - pl.Series("divide_array", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), - ] - ) - assert_frame_equal(out.collect(), expected) - - def test_argminmax() -> None: ldf = pl.LazyFrame({"a": [1, 2, 3, 4, 5], "b": [1, 1, 2, 2, 2]}) out = ldf.select( @@ -1106,23 +1067,6 @@ def test_self_join() -> None: } -def test_preservation_of_subclasses() -> None: - """Test for LazyFrame inheritance.""" - - # We should be able to inherit from polars.LazyFrame - class SubClassedLazyFrame(pl.LazyFrame): - pass - - # The constructor creates an object which is an instance of both the - # superclass and subclass - ldf = pl.DataFrame({"column_1": [1, 2, 3]}).lazy() - ldf.__class__ = SubClassedLazyFrame - extended_ldf = ldf.with_columns(pl.lit(1).alias("column_2")) - - assert isinstance(extended_ldf, pl.LazyFrame) - assert isinstance(extended_ldf, SubClassedLazyFrame) - - def test_group_lengths() -> None: ldf = pl.LazyFrame( { diff --git a/py-polars/tests/unit/test_polars_import.py b/py-polars/tests/unit/test_polars_import.py index f9289ea8dda4..a0081b02a2e5 100644 --- a/py-polars/tests/unit/test_polars_import.py +++ b/py-polars/tests/unit/test_polars_import.py @@ -28,11 +28,11 @@ def _import_timings() -> bytes: # assemble suitable command to get polars module import timing; # run in a separate process to ensure clean timing results. cmd = f'{sys.executable} -S -X importtime -c "import polars"' - return ( - subprocess.run(cmd, shell=True, capture_output=True) - .stderr.replace(b"import time:", b"") - .strip() - ) + output = subprocess.run(cmd, shell=True, capture_output=True).stderr + if b"Traceback" in output: + msg = f"measuring import timings failed\n\nCommand output:\n{output.decode()}" + raise RuntimeError(msg) + return output.replace(b"import time:", b"").strip() def _import_timings_as_frame(n_tries: int) -> tuple[pl.DataFrame, int]: diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index 0a8c180806a1..bbba9b6579b0 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -389,3 +389,22 @@ def test_rolling_key_projected_13617() -> None: assert r'DF ["idx", "value"]; PROJECT 2/2 COLUMNS' in plan out = ldf.collect(projection_pushdown=True) assert out.to_dict(as_series=False) == {"value": [["a"], ["b"]]} + + +def test_projection_drop_with_series_lit_14382() -> None: + df = pl.DataFrame({"b": [1, 6, 8, 7]}) + df2 = pl.DataFrame({"a": [1, 2, 4, 4], "b": [True, True, True, False]}) + + q = ( + df2.lazy() + .select( + *["a", "b"], pl.lit("b").alias("b_name"), df.get_column("b").alias("b_old") + ) + .filter(pl.col("b").not_()) + .drop("b") + ) + assert q.collect().to_dict(as_series=False) == { + "a": [4], + "b_name": ["b"], + "b_old": [7], + } diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index 1a28b608ae06..83bf5f1b1735 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime, timedelta +from datetime import date, datetime, time, timedelta from typing import Any import numpy as np @@ -368,3 +368,37 @@ def test_shift_drop_nulls_10875() -> None: assert pl.LazyFrame({"a": [1, 2, 3]}).shift(1).drop_nulls().collect()[ "a" ].to_list() == [1, 2] + + +def test_temporal_downcasts() -> None: + s = pl.Series([-1, 0, 1]).cast(pl.Datetime("us")) + + assert s.to_list() == [ + datetime(1969, 12, 31, 23, 59, 59, 999999), + datetime(1970, 1, 1), + datetime(1970, 1, 1, 0, 0, 0, 1), + ] + + # downcast (from us to ms, or from datetime to date) should NOT change the date + for s_dt in (s.dt.date(), s.cast(pl.Date)): + assert s_dt.to_list() == [ + date(1969, 12, 31), + date(1970, 1, 1), + date(1970, 1, 1), + ] + assert s.cast(pl.Datetime("ms")).to_list() == [ + datetime(1969, 12, 31, 23, 59, 59, 999000), + datetime(1970, 1, 1), + datetime(1970, 1, 1), + ] + + +def test_temporal_time_casts() -> None: + s = pl.Series([-1, 0, 1]).cast(pl.Datetime("us")) + + for s_dt in (s.dt.time(), s.cast(pl.Time)): + assert s_dt.to_list() == [ + time(23, 59, 59, 999999), + time(0, 0, 0, 0), + time(0, 0, 0, 1), + ] diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index f5eb9a8e4b57..d3ee82446efe 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -629,6 +629,6 @@ def test_literal_subtract_schema_13284() -> None: assert ( pl.LazyFrame({"a": [23, 30]}, schema={"a": pl.UInt8}) .with_columns(pl.col("a") - pl.lit(1)) - .group_by(by="a") + .group_by("a") .len() ).schema == OrderedDict([("a", pl.UInt8), ("len", pl.UInt32)]) diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index 0bc4250afbf3..a61da1fcfea7 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -504,14 +504,14 @@ def test_regex_expansion_group_by_9947() -> None: def test_regex_expansion_exclude_10002() -> None: df = pl.DataFrame({"col_1": [1, 2, 3], "col_2": [2, 4, 3]}) - expected = {"col_1": [10, 20, 30], "col_2": [0.2, 0.4, 0.3]} + expected = pl.DataFrame({"col_1": [10, 20, 30], "col_2": [0.2, 0.4, 0.3]}) - assert ( + assert_frame_equal( df.select( pl.col("^col_.*$").exclude("col_2").mul(10), pl.col("^col_.*$").exclude("col_1") / 10, - ).to_dict(as_series=False) - == expected + ), + expected, ) diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py index 6fb79230e89b..99cec90a13dd 100644 --- a/py-polars/tests/unit/test_serde.py +++ b/py-polars/tests/unit/test_serde.py @@ -33,13 +33,10 @@ def test_lazyframe_serde() -> None: def test_serde_time_unit() -> None: - assert pickle.loads( - pickle.dumps( - pl.Series( - [datetime(2022, 1, 1) + timedelta(days=1) for _ in range(3)] - ).cast(pl.Datetime("ns")) - ) - ).dtype == pl.Datetime("ns") + values = [datetime(2022, 1, 1) + timedelta(days=1) for _ in range(3)] + s = pl.Series(values).cast(pl.Datetime("ns")) + result = pickle.loads(pickle.dumps(s)) + assert result.dtype == pl.Datetime("ns") def test_serde_duration() -> None: @@ -103,14 +100,6 @@ def test_deser_empty_list() -> None: assert s.to_list() == [[[42.0]], []] -def test_expression_json() -> None: - e = pl.col("foo").sum().over("bar") - json = e.meta.write_json() - - round_tripped = pl.Expr.from_json(json) - assert round_tripped.meta == e - - def times2(x: pl.Series) -> pl.Series: return x * 2 @@ -205,9 +194,40 @@ def test_serde_array_dtype() -> None: assert_series_equal(pickle.loads(pickle.dumps(nested_s)), nested_s) +def test_expr_serialization_roundtrip() -> None: + expr = pl.col("foo").sum().over("bar") + json = expr.meta.serialize() + round_tripped = pl.Expr.deserialize(io.StringIO(json)) + assert round_tripped.meta == expr + + +def test_expr_deserialize_file_not_found() -> None: + with pytest.raises(FileNotFoundError): + pl.Expr.deserialize("abcdef") + + +def test_expr_deserialize_invalid_json() -> None: + with pytest.raises( + pl.ComputeError, match="could not deserialize input into an expression" + ): + pl.Expr.deserialize(io.StringIO("abcdef")) + + +def test_expr_write_json_from_json_deprecated() -> None: + expr = pl.col("foo").sum().over("bar") + + with pytest.deprecated_call(): + json = expr.meta.write_json() + + with pytest.deprecated_call(): + round_tripped = pl.Expr.from_json(json) + + assert round_tripped.meta == expr + + def test_expression_json_13991() -> None: - e = pl.col("foo").cast(pl.Decimal) - json = e.meta.write_json() + expr = pl.col("foo").cast(pl.Decimal) + json = expr.meta.serialize() - round_tripped = pl.Expr.from_json(json) - assert round_tripped.meta == e + round_tripped = pl.Expr.deserialize(io.StringIO(json)) + assert round_tripped.meta == expr diff --git a/py-polars/tests/unit/test_simplify.py b/py-polars/tests/unit/test_simplify.py new file mode 100644 index 000000000000..e7bc7ec819e3 --- /dev/null +++ b/py-polars/tests/unit/test_simplify.py @@ -0,0 +1,10 @@ +import polars as pl + + +def test_flatten_alias() -> None: + assert ( + """len().alias("bar")""" + in pl.LazyFrame({"a": [1, 2]}) + .select(pl.len().alias("foo").alias("bar")) + .explain() + ) diff --git a/py-polars/tests/unit/testing/test_assert_frame_equal.py b/py-polars/tests/unit/testing/test_assert_frame_equal.py index a5d00abc6eb9..bf8727c178a1 100644 --- a/py-polars/tests/unit/testing/test_assert_frame_equal.py +++ b/py-polars/tests/unit/testing/test_assert_frame_equal.py @@ -10,6 +10,7 @@ from polars.testing import assert_frame_equal, assert_frame_not_equal nan = float("nan") +pytest_plugins = ["pytester"] @pytest.mark.parametrize( @@ -366,3 +367,66 @@ def test_assert_frame_not_equal() -> None: df = pl.DataFrame({"a": [1, 2]}) with pytest.raises(AssertionError, match="frames are equal"): assert_frame_not_equal(df, df) + + +def test_tracebackhide(testdir: pytest.Testdir) -> None: + testdir.makefile( + ".py", + test_path="""\ +import polars as pl +from polars.testing import assert_frame_equal, assert_frame_not_equal + +def test_frame_equal_fail(): + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 3]}) + assert_frame_equal(df1, df2) + +def test_frame_not_equal_fail(): + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2]}) + assert_frame_not_equal(df1, df2) + +def test_frame_data_type_fail(): + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = {"a": [1, 2]} + assert_frame_equal(df1, df2) + +def test_frame_schema_fail(): + df1 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int64}) + df2 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int32}) + assert_frame_equal(df1, df2) +""", + ) + result = testdir.runpytest() + result.assert_outcomes(passed=0, failed=4) + stdout = "\n".join(result.outlines) + + assert "polars/py-polars/polars/testing" not in stdout + + # The above should catch any polars testing functions that appear in the + # stack trace. But we keep the following checks (for specific function + # names) just to double-check. + + assert "def assert_frame_equal" not in stdout + assert "def assert_frame_not_equal" not in stdout + assert "def _assert_correct_input_type" not in stdout + assert "def _assert_frame_schema_equal" not in stdout + + assert "def assert_series_equal" not in stdout + assert "def assert_series_not_equal" not in stdout + assert "def _assert_series_values_equal" not in stdout + assert "def _assert_series_nested_values_equal" not in stdout + assert "def _assert_series_null_values_match" not in stdout + assert "def _assert_series_nan_values_match" not in stdout + assert "def _assert_series_values_within_tolerance" not in stdout + + # Make sure the tests are failing for the expected reason (e.g. not because + # an import is missing or something like that): + + assert ( + "AssertionError: DataFrames are different (value mismatch for column 'a')" + in stdout + ) + assert "AssertionError: frames are equal" in stdout + assert "AssertionError: inputs are different (unexpected input types)" in stdout + assert "AssertionError: DataFrames are different (dtypes do not match)" in stdout diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py index 0def3fc26c3f..e676be77b1ac 100644 --- a/py-polars/tests/unit/testing/test_assert_series_equal.py +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -11,6 +11,7 @@ from polars.testing import assert_series_equal, assert_series_not_equal nan = float("nan") +pytest_plugins = ["pytester"] def test_compare_series_value_mismatch() -> None: @@ -619,24 +620,16 @@ def test_series_equal_nested_lengths_mismatch() -> None: assert_series_equal(s1, s2) -def test_series_equal_decimals_exact() -> None: - s1 = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal) - s2 = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal) - with pytest.raises(AssertionError, match="exact value mismatch"): - assert_series_equal(s1, s2, check_exact=True) - - -def test_series_equal_decimals_inexact() -> None: +@pytest.mark.parametrize("check_exact", [True, False]) +def test_series_equal_decimals(check_exact: bool) -> None: s1 = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal) s2 = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal) - assert_series_equal(s1, s2, check_exact=False) + assert_series_equal(s1, s1, check_exact=check_exact) + assert_series_equal(s2, s2, check_exact=check_exact) -def test_series_equal_decimals_inexact_fail() -> None: - s1 = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal) - s2 = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal) - with pytest.raises(AssertionError, match="value mismatch"): - assert_series_equal(s1, s2, check_exact=False, rtol=0) + with pytest.raises(AssertionError, match="exact value mismatch"): + assert_series_equal(s1, s2, check_exact=check_exact) def test_assert_series_equal_w_large_integers_12328() -> None: @@ -644,3 +637,81 @@ def test_assert_series_equal_w_large_integers_12328() -> None: right = pl.Series([1577840521123543]) with pytest.raises(AssertionError): assert_series_equal(left, right) + + +def test_tracebackhide(testdir: pytest.Testdir) -> None: + testdir.makefile( + ".py", + test_path="""\ +import polars as pl +from polars.testing import assert_series_equal, assert_series_not_equal + +nan = float("nan") + +def test_series_equal_fail(): + s1 = pl.Series([1, 2]) + s2 = pl.Series([1, 3]) + assert_series_equal(s1, s2) + +def test_series_not_equal_fail(): + s1 = pl.Series([1, 2]) + s2 = pl.Series([1, 2]) + assert_series_not_equal(s1, s2) + +def test_series_nested_fail(): + s1 = pl.Series([[1, 2], [3, 4]]) + s2 = pl.Series([[1, 2], [3, 5]]) + assert_series_equal(s1, s2) + +def test_series_null_fail(): + s1 = pl.Series([1, 2]) + s2 = pl.Series([1, None]) + assert_series_equal(s1, s2) + +def test_series_nan_fail(): + s1 = pl.Series([1.0, 2.0]) + s2 = pl.Series([1.0, nan]) + assert_series_equal(s1, s2) + +def test_series_float_tolerance_fail(): + s1 = pl.Series([1.0, 2.0]) + s2 = pl.Series([1.0, 2.1]) + assert_series_equal(s1, s2) + +def test_series_schema_fail(): + s1 = pl.Series([1, 2], dtype=pl.Int64) + s2 = pl.Series([1, 2], dtype=pl.Int32) + assert_series_equal(s1, s2) + +def test_series_data_type_fail(): + s1 = pl.Series([1, 2]) + s2 = [1, 2] + assert_series_equal(s1, s2) +""", + ) + result = testdir.runpytest() + result.assert_outcomes(passed=0, failed=8) + stdout = "\n".join(result.outlines) + + assert "polars/py-polars/polars/testing" not in stdout + + # The above should catch any polars testing functions that appear in the + # stack trace. But we keep the following checks (for specific function + # names) just to double-check. + + assert "def assert_series_equal" not in stdout + assert "def assert_series_not_equal" not in stdout + assert "def _assert_series_values_equal" not in stdout + assert "def _assert_series_nested_values_equal" not in stdout + assert "def _assert_series_null_values_match" not in stdout + assert "def _assert_series_nan_values_match" not in stdout + assert "def _assert_series_values_within_tolerance" not in stdout + + # Make sure the tests are failing for the expected reason (e.g. not because + # an import is missing or something like that): + + assert "AssertionError: Series are different (exact value mismatch)" in stdout + assert "AssertionError: Series are equal" in stdout + assert "AssertionError: Series are different (nan value mismatch)" in stdout + assert "AssertionError: Series are different (dtype mismatch)" in stdout + assert "AssertionError: inputs are different (unexpected input types)" in stdout diff --git a/py-polars/tests/unit/utils/test_build_info.py b/py-polars/tests/unit/utils/test_build_info.py deleted file mode 100644 index cd9f73a40a66..000000000000 --- a/py-polars/tests/unit/utils/test_build_info.py +++ /dev/null @@ -1,9 +0,0 @@ -import polars as pl - - -def test_build_info() -> None: - build_info = pl.build_info() - assert "version" in build_info # version is always present - features = build_info.get("features", {}) - if features: # only when compiled with `build_info` feature gate - assert "BUILD_INFO" in features diff --git a/py-polars/tests/unit/utils/test_utils.py b/py-polars/tests/unit/utils/test_utils.py index fc84cc3d59f8..1b5e79a40586 100644 --- a/py-polars/tests/unit/utils/test_utils.py +++ b/py-polars/tests/unit/utils/test_utils.py @@ -32,13 +32,18 @@ @pytest.mark.parametrize( ("dt", "time_unit", "expected"), [ - (datetime(2121, 1, 1), "ns", 4765132800000000000), - (datetime(2121, 1, 1), "us", 4765132800000000), - (datetime(2121, 1, 1), "ms", 4765132800000), + (datetime(2121, 1, 1), "ns", 4_765_132_800_000_000_000), + (datetime(2121, 1, 1), "us", 4_765_132_800_000_000), + (datetime(2121, 1, 1), "ms", 4_765_132_800_000), + (datetime(2121, 1, 1), None, 4_765_132_800_000_000), + (datetime.min, "ns", -62_135_596_800_000_000_000), + (datetime.max, "ns", 253_402_300_799_999_999_000), + (datetime.min, "ms", -62_135_596_800_000), + (datetime.max, "ms", 253_402_300_799_999), ], ) def test_datetime_to_pl_timestamp( - dt: datetime, time_unit: TimeUnit, expected: int + dt: datetime, time_unit: TimeUnit | None, expected: int ) -> None: out = _datetime_to_pl_timestamp(dt, time_unit) assert out == expected @@ -47,31 +52,47 @@ def test_datetime_to_pl_timestamp( @pytest.mark.parametrize( ("t", "expected"), [ - (time(0, 0, 0), 0), (time(0, 0, 1), 1_000_000_000), (time(20, 52, 10), 75_130_000_000_000), (time(20, 52, 10, 200), 75_130_000_200_000), + (time.min, 0), + (time.max, 86_399_999_999_000), ], ) def test_time_to_pl_time(t: time, expected: int) -> None: assert _time_to_pl_time(t) == expected -def test_date_to_pl_date() -> None: - d = date(1999, 9, 9) - out = _date_to_pl_date(d) - assert out == 10843 +@pytest.mark.parametrize( + ("d", "expected"), + [ + (date(1999, 9, 9), 10_843), + (date(1969, 12, 31), -1), + (date.min, -719_162), + (date.max, 2_932_896), + ], +) +def test_date_to_pl_date(d: date, expected: int) -> None: + assert _date_to_pl_date(d) == expected -def test_timedelta_to_pl_timedelta() -> None: - out = _timedelta_to_pl_timedelta(timedelta(days=1), "ns") - assert out == 86_400_000_000_000 - out = _timedelta_to_pl_timedelta(timedelta(days=1), "us") - assert out == 86_400_000_000 - out = _timedelta_to_pl_timedelta(timedelta(days=1), "ms") - assert out == 86_400_000 - out = _timedelta_to_pl_timedelta(timedelta(days=1), time_unit=None) - assert out == 86_400_000_000 +@pytest.mark.parametrize( + ("td", "time_unit", "expected"), + [ + (timedelta(days=1), "ns", 86_400_000_000_000), + (timedelta(days=1), "us", 86_400_000_000), + (timedelta(days=1), "ms", 86_400_000), + (timedelta(days=1), None, 86_400_000_000), + (timedelta.min, "ns", -86_399_999_913_600_000_000_000), + (timedelta.max, "ns", 86_399_999_999_999_999_999_000), + (timedelta.min, "ms", -86_399_999_913_600_000), + (timedelta.max, "ms", 86_399_999_999_999_999), + ], +) +def test_timedelta_to_pl_timedelta( + td: timedelta, time_unit: TimeUnit | None, expected: int +) -> None: + assert _timedelta_to_pl_timedelta(td, time_unit) == expected @pytest.mark.parametrize( diff --git a/rust-toolchain.toml b/rust-toolchain.toml index f1b98f9ea712..5f75e2c9af81 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-01-24" +channel = "nightly-2024-02-23"