diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py
index f86ba6ddc8..46d71a079e 100644
--- a/bigframes/session/__init__.py
+++ b/bigframes/session/__init__.py
@@ -61,7 +61,7 @@
 import bigframes._config.bigquery_options as bigquery_options
 import bigframes.clients
 import bigframes.constants
-from bigframes.core import blocks, log_adapter
+from bigframes.core import blocks, log_adapter, utils
 import bigframes.core.pyformat
 
 # Even though the ibis.backends.bigquery import is unused, it's needed
@@ -1108,11 +1108,8 @@ def _read_csv_w_bigquery_engine(
         native CSV loading capabilities, making it suitable for large datasets
         that may not fit into local memory.
         """
-        if dtype is not None:
-            raise NotImplementedError(
-                f"BigQuery engine does not support the `dtype` argument."
-                f"{constants.FEEDBACK_LINK}"
-            )
+        if dtype is not None and not utils.is_dict_like(dtype):
+            raise ValueError("dtype should be a dict-like object.")
 
         if names is not None:
             if len(names) != len(set(names)):
@@ -1167,10 +1164,16 @@ def _read_csv_w_bigquery_engine(
             job_config.skip_leading_rows = header + 1
 
         table_id = self._loader.load_file(filepath_or_buffer, job_config=job_config)
-        return self._loader.read_gbq_table(
+        df = self._loader.read_gbq_table(
             table_id, index_col=index_col, columns=columns, names=names
         )
 
+        if dtype is not None:
+            for column, dtype in dtype.items():
+                if column in df.columns:
+                    df[column] = df[column].astype(dtype)
+        return df
+
     def read_pickle(
         self,
         filepath_or_buffer: FilePath | ReadPickleBuffer,
diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py
index ce5d3d66b6..8b0a1266ce 100644
--- a/bigframes/session/loader.py
+++ b/bigframes/session/loader.py
@@ -663,9 +663,10 @@ def read_gbq_table(
         renamed_cols: Dict[str, str] = {
             col: new_name for col, new_name in zip(array_value.column_ids, names)
         }
-        index_names = [
-            renamed_cols.get(index_col, index_col) for index_col in index_cols
-        ]
+        if index_col != bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
+            index_names = [
+                renamed_cols.get(index_col, index_col) for index_col in index_cols
+            ]
         value_columns = [renamed_cols.get(col, col) for col in value_columns]
 
         block = blocks.Block(
diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py
index 2a58061607..dfb69d628e 100644
--- a/tests/system/small/test_session.py
+++ b/tests/system/small/test_session.py
@@ -1369,6 +1369,45 @@ def test_read_csv_for_names_and_index_col(
     )
 
 
+def test_read_csv_for_dtype(session, df_and_gcs_csv_for_two_columns):
+    _, path = df_and_gcs_csv_for_two_columns
+
+    dtype = {"bool_col": pd.BooleanDtype(), "int64_col": pd.Float64Dtype()}
+    bf_df = session.read_csv(path, engine="bigquery", dtype=dtype)
+
+    # Convert default pandas dtypes to match BigQuery DataFrames dtypes.
+    pd_df = session.read_csv(path, dtype=dtype)
+
+    assert bf_df.shape == pd_df.shape
+    assert bf_df.columns.tolist() == pd_df.columns.tolist()
+
+    # BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
+    # (b/280889935) or guarantee row ordering.
+    bf_df = bf_df.set_index("rowindex").sort_index()
+    pd_df = pd_df.set_index("rowindex")
+    pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
+
+
+def test_read_csv_for_dtype_w_names(session, df_and_gcs_csv_for_two_columns):
+    _, path = df_and_gcs_csv_for_two_columns
+
+    names = ["a", "b", "c"]
+    dtype = {"b": pd.BooleanDtype(), "c": pd.Float64Dtype()}
+    bf_df = session.read_csv(path, engine="bigquery", names=names, dtype=dtype)
+
+    # Convert default pandas dtypes to match BigQuery DataFrames dtypes.
+    pd_df = session.read_csv(path, names=names, dtype=dtype)
+
+    assert bf_df.shape == pd_df.shape
+    assert bf_df.columns.tolist() == pd_df.columns.tolist()
+
+    # BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
+    # (b/280889935) or guarantee row ordering.
+    bf_df = bf_df.set_index("a").sort_index()
+    pd_df = pd_df.set_index("a")
+    pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
+
+
 @pytest.mark.parametrize(
     ("kwargs", "match"),
     [
diff --git a/tests/unit/session/test_session.py b/tests/unit/session/test_session.py
index dc8ee2c0d9..cbd31f588a 100644
--- a/tests/unit/session/test_session.py
+++ b/tests/unit/session/test_session.py
@@ -108,11 +108,6 @@
 @pytest.mark.parametrize(
     ("kwargs", "match"),
     [
-        pytest.param(
-            {"engine": "bigquery", "dtype": {}},
-            "BigQuery engine does not support the `dtype` argument",
-            id="with_dtype",
-        ),
         pytest.param(
             {"engine": "bigquery", "usecols": [1, 2]},
             "BigQuery engine only supports an iterable of strings for `usecols`.",
@@ -215,6 +210,17 @@ def test_read_csv_w_bigquery_engine_raises_error_for_invalid_names(
         session.read_csv("path/to/csv.csv", engine="bigquery", names=names)
 
 
+def test_read_csv_w_bigquery_engine_raises_error_for_invalid_dtypes():
+    session = mocks.create_bigquery_session()
+
+    with pytest.raises(ValueError, match="dtype should be a dict-like object."):
+        session.read_csv(
+            "path/to/csv.csv",
+            engine="bigquery",
+            dtype=["a", "b", "c"],  # type: ignore[arg-type]
+        )
+
+
 @pytest.mark.parametrize("missing_parts_table_id", [(""), ("table")])
 def test_read_gbq_missing_parts(missing_parts_table_id):
     session = mocks.create_bigquery_session()
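With this change, `read_csv(engine="bigquery")` accepts a dict-like `dtype` and applies each entry as an `astype` cast after the BigQuery load, rather than rejecting the argument with `NotImplementedError`. Columns named in `dtype` but absent from the loaded table are skipped, since the cast loop guards on `column in df.columns`. A minimal usage sketch of the new behavior (the `gs://` path and column names here are hypothetical; the call pattern mirrors the system tests above):

    import bigframes.pandas as bpd
    import pandas as pd

    # Hypothetical CSV with "bool_col" and "int64_col" columns.
    df = bpd.read_csv(
        "gs://my-bucket/data.csv",
        engine="bigquery",
        dtype={"bool_col": pd.BooleanDtype(), "int64_col": pd.Float64Dtype()},
    )

    # A non-dict-like dtype now fails fast with
    # ValueError("dtype should be a dict-like object."):
    # bpd.read_csv("gs://my-bucket/data.csv", engine="bigquery", dtype=[1, 2])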