diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 9c0953c35..3ed6d40fe 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -21,6 +21,7 @@
 
 from __future__ import annotations
 
+
 from typing import Any, Iterable, List, Literal, TYPE_CHECKING
 from datafusion.record_batch import RecordBatchStream
 from typing_extensions import deprecated
@@ -267,6 +268,18 @@ def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
         exprs_raw = [sort_or_default(expr) for expr in exprs]
         return DataFrame(self.df.sort(*exprs_raw))
 
+    def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
+        """Cast one or more columns to a different data type.
+
+        Args:
+            mapping: Maps each column name to the data type that column is cast to.
+
+        Returns:
+            DataFrame produced by passing the cast expressions to ``with_columns``.
+        """
+        exprs = [Expr.column(col).cast(dtype) for col, dtype in mapping.items()]
+        return self.with_columns(exprs)
+
     def limit(self, count: int, offset: int = 0) -> DataFrame:
         """Return a new :py:class:`DataFrame` with a limited number of rows.
 
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index 0d4a7dcb0..bb408c9c9 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -247,6 +247,15 @@ def test_with_columns(df):
     assert result.column(6) == pa.array([5, 7, 9])
 
 
+def test_cast(df):
+    df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
+    expected = pa.schema(
+        [("a", pa.float16()), ("b", pa.list_(pa.uint32())), ("c", pa.int64())]
+    )
+
+    assert df.schema() == expected
+
+
 def test_with_column_renamed(df):
     df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum")
 