From 14424d2c7d5ed991ee09acb6d07fa170f8ca0dcd Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Wed, 9 Jul 2025 17:48:43 +0000 Subject: [PATCH 1/2] fix: Used query row count metadata instead of table metadata --- bigframes/core/nodes.py | 2 +- bigframes/session/bq_caching_executor.py | 21 +++++++++++++-------- tests/system/small/test_dataframe.py | 14 ++++---------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 205621fee2..5d3c814437 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -161,7 +161,7 @@ def is_noop(self) -> bool: return ( ((not self.start) or (self.start == 0)) and (self.step == 1) - and ((self.stop is None) or (self.stop == self.row_count)) + and ((self.stop is None) or (self.stop == self.child.row_count)) ) @property diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index 6750652bc2..a970e75a0f 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -100,6 +100,7 @@ def cache_results_table( original_root: nodes.BigFrameNode, table: bigquery.Table, ordering: order.RowOrdering, + num_rows: Optional[int] = None, ): # Assumption: GBQ cached table uses field name as bq column name scan_list = nodes.ScanList( @@ -112,7 +113,7 @@ def cache_results_table( source=nodes.BigqueryDataSource( nodes.GbqTable.from_table(table), ordering=ordering, - n_rows=table.num_rows, + n_rows=num_rows, ), scan_list=scan_list, table_session=original_root.session, @@ -468,14 +469,16 @@ def _cache_with_cluster_cols( plan, sort_rows=False, materialize_all_order_keys=True ) ) - tmp_table_ref = self._sql_as_cached_temp_table( + tmp_table_ref, num_rows = self._sql_as_cached_temp_table( compiled.sql, compiled.sql_schema, cluster_cols=bq_io.select_cluster_cols(compiled.sql_schema, cluster_cols), ) tmp_table = self.bqclient.get_table(tmp_table_ref) assert compiled.row_order is not None - self.cache.cache_results_table(array_value.node, tmp_table, compiled.row_order) + self.cache.cache_results_table( + array_value.node, tmp_table, compiled.row_order, num_rows=num_rows + ) def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue): """Executes the query and uses the resulting table to rewrite future executions.""" @@ -487,14 +490,16 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue): sort_rows=False, ) ) - tmp_table_ref = self._sql_as_cached_temp_table( + tmp_table_ref, num_rows = self._sql_as_cached_temp_table( compiled.sql, compiled.sql_schema, cluster_cols=[offset_column], ) tmp_table = self.bqclient.get_table(tmp_table_ref) assert compiled.row_order is not None - self.cache.cache_results_table(array_value.node, tmp_table, compiled.row_order) + self.cache.cache_results_table( + array_value.node, tmp_table, compiled.row_order, num_rows=num_rows + ) def _cache_with_session_awareness( self, @@ -552,7 +557,7 @@ def _sql_as_cached_temp_table( sql: str, schema: Sequence[bigquery.SchemaField], cluster_cols: Sequence[str], - ) -> bigquery.TableReference: + ) -> tuple[bigquery.TableReference, Optional[int]]: assert len(cluster_cols) <= _MAX_CLUSTER_COLUMNS temp_table = self.storage_manager.create_temp_table(schema, cluster_cols) @@ -567,8 +572,8 @@ def _sql_as_cached_temp_table( job_config=job_config, ) assert query_job is not None - query_job.result() - return query_job.destination + iter = query_job.result() + return query_job.destination, iter.total_rows def _validate_result_schema( self, diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 3a1c54302f..ec92194887 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -3438,18 +3438,12 @@ def test_loc_select_columns_w_repeats(scalars_df_index, scalars_pandas_df_index) ("start", "stop", "step"), [ (0, 0, None), - (None, None, None), - (1, None, None), - (None, 4, None), - (None, None, 2), - (None, 50000000000, 1), - (5, 4, None), - (3, None, 2), - (1, 7, 2), - (1, 7, 50000000000), ], ) -def test_iloc_slice(scalars_df_index, scalars_pandas_df_index, start, stop, step): +def test_iloc_slice_after_cache( + scalars_df_index, scalars_pandas_df_index, start, stop, step +): + scalars_df_index.cache() bf_result = scalars_df_index.iloc[start:stop:step].to_pandas() pd_result = scalars_pandas_df_index.iloc[start:stop:step] pd.testing.assert_frame_equal( From 73ea686651c1a7d6e155af6ac5c9151cc08e91d2 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Wed, 9 Jul 2025 19:26:26 +0000 Subject: [PATCH 2/2] add back test --- tests/system/small/test_dataframe.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index ec92194887..8ea1259325 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -3434,6 +3434,30 @@ def test_loc_select_columns_w_repeats(scalars_df_index, scalars_pandas_df_index) ) +@pytest.mark.parametrize( + ("start", "stop", "step"), + [ + (0, 0, None), + (None, None, None), + (1, None, None), + (None, 4, None), + (None, None, 2), + (None, 50000000000, 1), + (5, 4, None), + (3, None, 2), + (1, 7, 2), + (1, 7, 50000000000), + ], +) +def test_iloc_slice(scalars_df_index, scalars_pandas_df_index, start, stop, step): + bf_result = scalars_df_index.iloc[start:stop:step].to_pandas() + pd_result = scalars_pandas_df_index.iloc[start:stop:step] + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + @pytest.mark.parametrize( ("start", "stop", "step"), [