From ce1aa6799f18935fa022a298d4abaf841b705e76 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 1 May 2025 20:20:54 +0000 Subject: [PATCH] perf: Rechunk result pages client side --- bigframes/core/blocks.py | 6 +- bigframes/core/pyarrow_utils.py | 87 ++++++++++++++++++++++ bigframes/session/_io/bigquery/__init__.py | 10 +-- bigframes/session/bq_caching_executor.py | 23 +----- bigframes/session/executor.py | 25 ++++++- bigframes/session/loader.py | 2 - tests/unit/core/test_pyarrow_utils.py | 65 ++++++++++++++++ tests/unit/session/test_io_bigquery.py | 7 +- 8 files changed, 182 insertions(+), 43 deletions(-) create mode 100644 bigframes/core/pyarrow_utils.py create mode 100644 tests/unit/core/test_pyarrow_utils.py diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index cc3b70f8a8..f8107486c1 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -595,10 +595,10 @@ def to_pandas_batches( self.expr, ordered=True, use_explicit_destination=allow_large_results, - page_size=page_size, - max_results=max_results, ) - for df in execute_result.to_pandas_batches(): + for df in execute_result.to_pandas_batches( + page_size=page_size, max_results=max_results + ): self._copy_index_to_pandas(df) if squeeze: yield df.squeeze(axis=1) diff --git a/bigframes/core/pyarrow_utils.py b/bigframes/core/pyarrow_utils.py new file mode 100644 index 0000000000..eead30d908 --- /dev/null +++ b/bigframes/core/pyarrow_utils.py @@ -0,0 +1,87 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterable, Iterator + +import pyarrow as pa + + +class BatchBuffer: + """ + FIFO buffer of pyarrow Record batches + + Not thread-safe. 
+ """ + + def __init__(self): + self._buffer: list[pa.RecordBatch] = [] + self._buffer_size: int = 0 + + def __len__(self): + return self._buffer_size + + def append_batch(self, batch: pa.RecordBatch) -> None: + self._buffer.append(batch) + self._buffer_size += batch.num_rows + + def take_as_batches(self, n: int) -> tuple[pa.RecordBatch, ...]: + if n > len(self): + raise ValueError(f"Cannot take {n} rows, only {len(self)} rows in buffer.") + rows_taken = 0 + sub_batches: list[pa.RecordBatch] = [] + while rows_taken < n: + batch = self._buffer.pop(0) + if batch.num_rows > (n - rows_taken): + sub_batches.append(batch.slice(length=n - rows_taken)) + self._buffer.insert(0, batch.slice(offset=n - rows_taken)) + rows_taken += n - rows_taken + else: + sub_batches.append(batch) + rows_taken += batch.num_rows + + self._buffer_size -= n + return tuple(sub_batches) + + def take_rechunked(self, n: int) -> pa.RecordBatch: + return ( + pa.Table.from_batches(self.take_as_batches(n)) + .combine_chunks() + .to_batches()[0] + ) + + +def chunk_by_row_count( + batches: Iterable[pa.RecordBatch], page_size: int +) -> Iterator[tuple[pa.RecordBatch, ...]]: + buffer = BatchBuffer() + for batch in batches: + buffer.append_batch(batch) + while len(buffer) >= page_size: + yield buffer.take_as_batches(page_size) + + # emit final page, maybe smaller + if len(buffer) > 0: + yield buffer.take_as_batches(len(buffer)) + + +def truncate_pyarrow_iterable( + batches: Iterable[pa.RecordBatch], max_results: int +) -> Iterator[pa.RecordBatch]: + total_yielded = 0 + for batch in batches: + if batch.num_rows >= (max_results - total_yielded): + yield batch.slice(length=max_results - total_yielded) + return + else: + yield batch + total_yielded += batch.num_rows diff --git a/bigframes/session/_io/bigquery/__init__.py b/bigframes/session/_io/bigquery/__init__.py index 6df9424e3b..48268d925d 100644 --- a/bigframes/session/_io/bigquery/__init__.py +++ b/bigframes/session/_io/bigquery/__init__.py @@ -222,8 +222,6 @@ def start_query_with_client( job_config: bigquery.job.QueryJobConfig, location: Optional[str] = None, project: Optional[str] = None, - max_results: Optional[int] = None, - page_size: Optional[int] = None, timeout: Optional[float] = None, api_name: Optional[str] = None, metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None, @@ -244,8 +242,6 @@ def start_query_with_client( location=location, project=project, api_timeout=timeout, - page_size=page_size, - max_results=max_results, ) if metrics is not None: metrics.count_job_stats(row_iterator=results_iterator) @@ -267,14 +263,10 @@ def start_query_with_client( if opts.progress_bar is not None and not query_job.configuration.dry_run: results_iterator = formatting_helpers.wait_for_query_job( query_job, - max_results=max_results, progress_bar=opts.progress_bar, - page_size=page_size, ) else: - results_iterator = query_job.result( - max_results=max_results, page_size=page_size - ) + results_iterator = query_job.result() if metrics is not None: metrics.count_job_stats(query_job=query_job) diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index ec5795f9a8..35a438403a 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -106,8 +106,6 @@ def execute( *, ordered: bool = True, use_explicit_destination: Optional[bool] = None, - page_size: Optional[int] = None, - max_results: Optional[int] = None, ) -> executor.ExecuteResult: if use_explicit_destination is None: 
use_explicit_destination = bigframes.options.bigquery.allow_large_results @@ -127,8 +125,6 @@ def execute( return self._execute_plan( plan, ordered=ordered, - page_size=page_size, - max_results=max_results, destination=destination_table, ) @@ -290,8 +286,6 @@ def _run_execute_query( sql: str, job_config: Optional[bq_job.QueryJobConfig] = None, api_name: Optional[str] = None, - page_size: Optional[int] = None, - max_results: Optional[int] = None, query_with_job: bool = True, ) -> Tuple[bq_table.RowIterator, Optional[bigquery.QueryJob]]: """ @@ -312,8 +306,6 @@ def _run_execute_query( sql, job_config=job_config, api_name=api_name, - max_results=max_results, - page_size=page_size, metrics=self.metrics, query_with_job=query_with_job, ) @@ -488,16 +480,13 @@ def _execute_plan( self, plan: nodes.BigFrameNode, ordered: bool, - page_size: Optional[int] = None, - max_results: Optional[int] = None, destination: Optional[bq_table.TableReference] = None, peek: Optional[int] = None, ): """Just execute whatever plan as is, without further caching or decomposition.""" # First try to execute fast-paths - # TODO: Allow page_size and max_results by rechunking/truncating results - if (not page_size) and (not max_results) and (not destination) and (not peek): + if (not destination) and (not peek): for semi_executor in self._semi_executors: maybe_result = semi_executor.execute(plan, ordered=ordered) if maybe_result: @@ -513,20 +502,12 @@ def _execute_plan( iterator, query_job = self._run_execute_query( sql=sql, job_config=job_config, - page_size=page_size, - max_results=max_results, query_with_job=(destination is not None), ) # Though we provide the read client, iterator may or may not use it based on what is efficient for the result def iterator_supplier(): - # Workaround issue fixed by: https://github.com/googleapis/python-bigquery/pull/2154 - if iterator._page_size is not None or iterator.max_results is not None: - return iterator.to_arrow_iterable(bqstorage_client=None) - else: - return iterator.to_arrow_iterable( - bqstorage_client=self.bqstoragereadclient - ) + return iterator.to_arrow_iterable(bqstorage_client=self.bqstoragereadclient) if query_job: size_bytes = self.bqclient.get_table(query_job.destination).num_bytes diff --git a/bigframes/session/executor.py b/bigframes/session/executor.py index 0ba4ee3c2d..3ce29cda18 100644 --- a/bigframes/session/executor.py +++ b/bigframes/session/executor.py @@ -25,6 +25,7 @@ import pyarrow import bigframes.core +from bigframes.core import pyarrow_utils import bigframes.core.schema import bigframes.session._io.pandas as io_pandas @@ -55,10 +56,28 @@ def to_arrow_table(self) -> pyarrow.Table: def to_pandas(self) -> pd.DataFrame: return io_pandas.arrow_to_pandas(self.to_arrow_table(), self.schema) - def to_pandas_batches(self) -> Iterator[pd.DataFrame]: + def to_pandas_batches( + self, page_size: Optional[int] = None, max_results: Optional[int] = None + ) -> Iterator[pd.DataFrame]: + assert (page_size is None) or (page_size > 0) + assert (max_results is None) or (max_results > 0) + batch_iter: Iterator[ + Union[pyarrow.Table, pyarrow.RecordBatch] + ] = self.arrow_batches() + if max_results is not None: + batch_iter = pyarrow_utils.truncate_pyarrow_iterable( + batch_iter, max_results + ) + + if page_size is not None: + batches_iter = pyarrow_utils.chunk_by_row_count(batch_iter, page_size) + batch_iter = map( + lambda batches: pyarrow.Table.from_batches(batches), batches_iter + ) + yield from map( functools.partial(io_pandas.arrow_to_pandas, schema=self.schema), - 
self.arrow_batches(), + batch_iter, ) def to_py_scalar(self): @@ -96,8 +115,6 @@ def execute( *, ordered: bool = True, use_explicit_destination: Optional[bool] = False, - page_size: Optional[int] = None, - max_results: Optional[int] = None, ) -> ExecuteResult: """ Execute the ArrayValue, storing the result to a temporary session-owned table. diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index e6b24e016c..1d68c2c4f8 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -809,7 +809,6 @@ def _start_query( self, sql: str, job_config: Optional[google.cloud.bigquery.QueryJobConfig] = None, - max_results: Optional[int] = None, timeout: Optional[float] = None, api_name: Optional[str] = None, ) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: @@ -828,7 +827,6 @@ def _start_query( self._bqclient, sql, job_config=job_config, - max_results=max_results, timeout=timeout, api_name=api_name, ) diff --git a/tests/unit/core/test_pyarrow_utils.py b/tests/unit/core/test_pyarrow_utils.py new file mode 100644 index 0000000000..155c36d268 --- /dev/null +++ b/tests/unit/core/test_pyarrow_utils.py @@ -0,0 +1,65 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +import numpy as np +import pyarrow as pa +import pytest + +from bigframes.core import pyarrow_utils + +PA_TABLE = pa.table({f"col_{i}": np.random.rand(1000) for i in range(10)}) + +# 17, 3, 929 coprime +N = 17 +MANY_SMALL_BATCHES = PA_TABLE.to_batches(max_chunksize=3) +FEW_BIG_BATCHES = PA_TABLE.to_batches(max_chunksize=929) + + +@pytest.mark.parametrize( + ["batches", "page_size"], + [ + (MANY_SMALL_BATCHES, N), + (FEW_BIG_BATCHES, N), + ], +) +def test_chunk_by_row_count(batches, page_size): + results = list(pyarrow_utils.chunk_by_row_count(batches, page_size=page_size)) + + for i, batches in enumerate(results): + if i != len(results) - 1: + assert sum(map(lambda x: x.num_rows, batches)) == page_size + else: + # final page can be smaller + assert sum(map(lambda x: x.num_rows, batches)) <= page_size + + reconstructed = pa.Table.from_batches(itertools.chain.from_iterable(results)) + assert reconstructed.equals(PA_TABLE) + + +@pytest.mark.parametrize( + ["batches", "max_rows"], + [ + (MANY_SMALL_BATCHES, N), + (FEW_BIG_BATCHES, N), + ], +) +def test_truncate_pyarrow_iterable(batches, max_rows): + results = list( + pyarrow_utils.truncate_pyarrow_iterable(batches, max_results=max_rows) + ) + + reconstructed = pa.Table.from_batches(results) + assert reconstructed.equals(PA_TABLE.slice(length=max_rows)) diff --git a/tests/unit/session/test_io_bigquery.py b/tests/unit/session/test_io_bigquery.py index af2c7714ab..14e5d1c2fe 100644 --- a/tests/unit/session/test_io_bigquery.py +++ b/tests/unit/session/test_io_bigquery.py @@ -199,11 +199,11 @@ def test_add_and_trim_labels_length_limit_met(): @pytest.mark.parametrize( - ("max_results", "timeout", "api_name"), - [(None, None, None), (100, 30.0, "test_api")], + ("timeout", "api_name"), + [(None, None), 
(30.0, "test_api")], ) def test_start_query_with_client_labels_length_limit_met( - mock_bq_client, max_results, timeout, api_name + mock_bq_client, timeout, api_name ): sql = "select * from abc" cur_labels = { @@ -230,7 +230,6 @@ def test_start_query_with_client_labels_length_limit_met( mock_bq_client, sql, job_config, - max_results=max_results, timeout=timeout, api_name=api_name, )
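
Usage note (illustrative, not part of the patch): the sketch below shows how the new
pyarrow_utils helpers compose to reproduce pagination on the client side, which is what
ExecuteResult.to_pandas_batches now does internally after this change. The 10-row example
table and the concrete page/row counts are assumptions chosen only to make the behavior
visible.

    import pyarrow as pa

    from bigframes.core import pyarrow_utils

    # A small table standing in for the arrow batches streamed back from BigQuery.
    table = pa.table({"x": list(range(10))})
    batches = table.to_batches(max_chunksize=3)  # uneven "server" pages: 3, 3, 3, 1 rows

    # First truncate to at most 7 rows, then regroup into client-side pages of 2 rows,
    # mirroring max_results followed by page_size in to_pandas_batches.
    truncated = pyarrow_utils.truncate_pyarrow_iterable(batches, max_results=7)
    for page in pyarrow_utils.chunk_by_row_count(truncated, page_size=2):
        # Each page is a tuple of RecordBatches totalling page_size rows;
        # the final page may be smaller.
        print(pa.Table.from_batches(page).num_rows)  # prints 2, 2, 2, 1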