diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py
index dbbf9ee864..c8632ebc8c 100644
--- a/bigframes/core/blocks.py
+++ b/bigframes/core/blocks.py
@@ -29,7 +29,17 @@
 import random
 import textwrap
 import typing
-from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union
+from typing import (
+    Iterable,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 import warnings

 import bigframes_vendored.constants as constants
@@ -87,14 +97,22 @@
 LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]]


-class BlockHolder(typing.Protocol):
-    """Interface for mutable objects with state represented by a block value object."""
+@dataclasses.dataclass
+class PandasBatches(Iterator[pd.DataFrame]):
+    """Iterator of pandas DataFrame batches, plus the total row count when known."""

-    def _set_block(self, block: Block):
-        """Set the underlying block value of the object"""
+    def __init__(
+        self, pandas_batches: Iterator[pd.DataFrame], total_rows: Optional[int] = 0
+    ):
+        self._dataframes: Iterator[pd.DataFrame] = pandas_batches
+        self._total_rows: Optional[int] = total_rows
+
+    @property
+    def total_rows(self) -> Optional[int]:
+        return self._total_rows

-    def _get_block(self) -> Block:
-        """Get the underlying block value of the object"""
+    def __next__(self) -> pd.DataFrame:
+        return next(self._dataframes)


 @dataclasses.dataclass()
@@ -599,8 +617,7 @@ def try_peek(
                 self.expr, n, use_explicit_destination=allow_large_results
             )
             df = result.to_pandas()
-            self._copy_index_to_pandas(df)
-            return df
+            return self._copy_index_to_pandas(df)
         else:
             return None

@@ -609,8 +626,7 @@ def to_pandas_batches(
         page_size: Optional[int] = None,
         max_results: Optional[int] = None,
         allow_large_results: Optional[bool] = None,
-        squeeze: Optional[bool] = False,
-    ):
+    ) -> Iterator[pd.DataFrame]:
         """Download results one message at a time.

         page_size and max_results determine the size and number of batches,
@@ -621,43 +637,43 @@ def to_pandas_batches(
             use_explicit_destination=allow_large_results,
         )

-        total_batches = 0
-        for df in execute_result.to_pandas_batches(
-            page_size=page_size, max_results=max_results
-        ):
-            total_batches += 1
-            self._copy_index_to_pandas(df)
-            if squeeze:
-                yield df.squeeze(axis=1)
-            else:
-                yield df
-
         # To reduce the number of edge cases to consider when working with the
         # results of this, always return at least one DataFrame. See:
         # b/428918844.
-        if total_batches == 0:
-            df = pd.DataFrame(
-                {
-                    col: pd.Series([], dtype=self.expr.get_column_type(col))
-                    for col in itertools.chain(self.value_columns, self.index_columns)
-                }
-            )
-            self._copy_index_to_pandas(df)
-            yield df
+        empty_val = pd.DataFrame(
+            {
+                col: pd.Series([], dtype=self.expr.get_column_type(col))
+                for col in itertools.chain(self.value_columns, self.index_columns)
+            }
+        )
+        dfs = map(
+            lambda a: a[0],
+            itertools.zip_longest(
+                execute_result.to_pandas_batches(page_size, max_results),
+                [0],
+                fillvalue=empty_val,
+            ),
+        )
+        dfs = iter(map(self._copy_index_to_pandas, dfs))

-    def _copy_index_to_pandas(self, df: pd.DataFrame):
-        """Set the index on pandas DataFrame to match this block.
+        total_rows = execute_result.total_rows
+        if (total_rows is not None) and (max_results is not None):
+            total_rows = min(total_rows, max_results)

-        Warning: This method modifies ``df`` inplace.
-        """
+        return PandasBatches(dfs, total_rows)
+
+    def _copy_index_to_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Set the index on pandas DataFrame to match this block."""
         # Note: If BigQuery DataFrame has null index, a default one will be created for the local materialization.
+        new_df = df.copy()
         if len(self.index_columns) > 0:
-            df.set_index(list(self.index_columns), inplace=True)
+            new_df.set_index(list(self.index_columns), inplace=True)
             # Pandas names is annotated as list[str] rather than the more
             # general Sequence[Label] that BigQuery DataFrames has.
             # See: https://github.com/pandas-dev/pandas-stubs/issues/804
-            df.index.names = self.index.names  # type: ignore
-        df.columns = self.column_labels
+            new_df.index.names = self.index.names  # type: ignore
+        new_df.columns = self.column_labels
+        return new_df

     def _materialize_local(
         self, materialize_options: MaterializationOptions = MaterializationOptions()
@@ -724,9 +740,7 @@ def _materialize_local(
             )
         else:
             df = execute_result.to_pandas()
-            self._copy_index_to_pandas(df)
-
-        return df, execute_result.query_job
+        return self._copy_index_to_pandas(df), execute_result.query_job

     def _downsample(
         self, total_rows: int, sampling_method: str, fraction: float, random_state
@@ -1591,8 +1605,7 @@ def retrieve_repr_request_results(
         row_count = self.session._executor.execute(self.expr.row_count()).to_py_scalar()

         head_df = head_result.to_pandas()
-        self._copy_index_to_pandas(head_df)
-        return head_df, row_count, head_result.query_job
+        return self._copy_index_to_pandas(head_df), row_count, head_result.query_job

     def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:
         expr, result_id = self._expr.promote_offsets()
diff --git a/bigframes/series.py b/bigframes/series.py
index ebc2913f78..d7833fef2a 100644
--- a/bigframes/series.py
+++ b/bigframes/series.py
@@ -648,13 +648,12 @@ def to_pandas_batches(
             form the original Series. Results stream from bigquery, see
             https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.table.RowIterator#google_cloud_bigquery_table_RowIterator_to_arrow_iterable
         """
-        df = self._block.to_pandas_batches(
+        batches = self._block.to_pandas_batches(
             page_size=page_size,
             max_results=max_results,
             allow_large_results=allow_large_results,
-            squeeze=True,
         )
-        return df
+        return map(lambda df: cast(pandas.Series, df.squeeze(1)), batches)

     def _compute_dry_run(self) -> bigquery.QueryJob:
         _, query_job = self._block._compute_dry_run((self._value_column,))
diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py
index 91a83dfd73..3a1c54302f 100644
--- a/tests/system/small/test_dataframe.py
+++ b/tests/system/small/test_dataframe.py
@@ -871,6 +871,21 @@ def test_filter_df(scalars_dfs):
     assert_pandas_df_equal(bf_result, pd_result)


+def test_df_to_pandas_batches(scalars_dfs):
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    capped_unfiltered_batches = scalars_df.to_pandas_batches(page_size=2, max_results=6)
+    bf_bool_series = scalars_df["bool_col"]
+    filtered_batches = scalars_df[bf_bool_series].to_pandas_batches()
+
+    pd_bool_series = scalars_pandas_df["bool_col"]
+    pd_result = scalars_pandas_df[pd_bool_series]
+
+    assert 6 == capped_unfiltered_batches.total_rows
+    assert len(pd_result) == filtered_batches.total_rows
+    assert_pandas_df_equal(pd.concat(filtered_batches), pd_result)
+
+
 def test_assign_new_column(scalars_dfs):
     scalars_df, scalars_pandas_df = scalars_dfs
     kwargs = {"new_col": 2}