fix: Product operation produces float result for all input types by TrevorBergeron · Pull Request #501 · googleapis/python-bigquery-dataframes · GitHub
[go: up one dir, main page]

Skip to content
Merged
2 changes: 1 addition & 1 deletion bigframes/core/compile/aggregate_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _(
.else_(magnitude * pow(-1, negative_count_parity))
.end()
)
return float_result.cast(column.type()) # type: ignore
return float_result


@compile_unary_agg.register
Expand Down
5 changes: 1 addition & 4 deletions bigframes/operations/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,7 @@ class ProductOp(UnaryAggregateOp):
name: ClassVar[str] = "product"

def output_type(self, *input_types: dtypes.ExpressionType):
if pd.api.types.is_bool_dtype(input_types[0]):
return dtypes.INT_DTYPE
else:
return input_types[0]
return dtypes.FLOAT_DTYPE


@dataclasses.dataclass(frozen=True)
Expand Down
3 changes: 1 addition & 2 deletions tests/system/small/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,7 @@ def test_dataframe_groupby_multi_sum(
(lambda x: x.cumsum(numeric_only=True)),
(lambda x: x.cummax(numeric_only=True)),
(lambda x: x.cummin(numeric_only=True)),
# pandas 2.2 uses floating point for cumulative product even for
# integer inputs.
# Pre-pandas 2.2 doesn't always produce float.
(lambda x: x.cumprod().astype("Float64")),
(lambda x: x.shift(periods=2)),
],
Expand Down
2 changes: 1 addition & 1 deletion tests/system/small/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,7 @@ def test_groupby_prod(scalars_dfs):
bf_series = scalars_df[col_name].groupby(scalars_df["int64_col"]).prod()
pd_series = (
scalars_pandas_df[col_name].groupby(scalars_pandas_df["int64_col"]).prod()
)
).astype(pd.Float64Dtype())
# TODO(swast): Update groupby to use index based on group by key(s).
bf_result = bf_series.to_pandas()
assert_series_equal(
Expand Down
8 changes: 4 additions & 4 deletions third_party/bigframes_vendored/pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4416,10 +4416,10 @@ def cumprod(self) -> DataFrame:
[3 rows x 2 columns]

>>> df.cumprod()
A B
0 3 1
1 3 2
2 6 6
A B
0 3.0 1.0
1 3.0 2.0
2 6.0 6.0
<BLANKLINE>
[3 rows x 2 columns]

Expand Down
0