8000 Upgrade to Datafusion 43 (#905) · kylebarron/datafusion-python@3c66201 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3c66201

Browse files
Upgrade to Datafusion 43 (apache#905)
* patch datafusion deps * migrate from deprecated RuntimeEnv::new to RuntimeEnv::try_new Ref: apache/datafusion#12566 * remove Arc from create_udf call Ref: apache/datafusion#12489 * doc typo * migrage new UnnestOptions API Ref: https://github.com/apache/datafusion/pull/12836/files * update API for logical expr Limit Ref: apache/datafusion#12836 * remove logical expr CrossJoin It was removed upstream. Ref: apache/datafusion#13076 * update PyWindowUDF Ref: apache/datafusion#12803 * migrate window functions lead and lag to udwf Ref: apache/datafusion#12802 * migrate window functions rank, dense_rank, and percent_rank to udwf Ref: apache/datafusion#12648 * convert window function cume_dist to udwf Ref: apache/datafusion#12695 * convert window function ntile to udwf Ref: apache/datafusion#12694 * clean up functions_window invocation * Only one column was being passed to udwf * Update to DF 43.0.0 * Update tests to look for string_view type * String view is now the default type for strings * Making a variety of adjustments in wrappers and unit tests to account for the switch from string to string_view as default * Resolve errors in doc building --------- Co-authored-by: Tim Saucer <timsaucer@gmail.com>
1 parent 4a6c4d1 commit 3c66201

File tree

19 files changed

+338
-338
lines changed
  • src
  • 19 files changed

    +338
    -338
    lines changed

    Cargo.lock

    Lines changed: 199 additions & 174 deletions
    Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

    Cargo.toml

    Lines changed: 5 additions & 4 deletions
    Original file line numberDiff line numberDiff line change
    @@ -37,9 +37,10 @@ substrait = ["dep:datafusion-substrait"]
    3737
    tokio = { version = "1.39", features = ["macros", "rt", "rt-multi-thread", "sync"] }
    3838
    pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
    3939
    arrow = { version = "53", features = ["pyarrow"] }
    40-
    datafusion = { version = "42.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
    41-
    datafusion-substrait = { version = "42.0.0", optional = true }
    42-
    datafusion-proto = { version = "42.0.0" }
    40+
    datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
    41+
    datafusion-substrait = { version = "43.0.0", optional = true }
    42+
    datafusion-proto = { version = "43.0.0" }
    43+
    datafusion-functions-window-common = { version = "43.0.0" }
    4344
    prost = "0.13" # keep in line with `datafusion-substrait`
    4445
    uuid = { version = "1.11", features = ["v4"] }
    4546
    mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }
    @@ -58,4 +59,4 @@ crate-type = ["cdylib", "rlib"]
    5859

    5960
    [profile.release]
    6061
    lto = true
    61-
    codegen-units = 1
    62+
    codegen-units = 1

    examples/tpch/_tests.py

    Lines changed: 2 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -25,7 +25,7 @@
    2525
    def df_selection(col_name, col_type):
    2626
    if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
    2727
    return F.round(col(col_name), lit(2)).alias(col_name)
    28-
    elif col_type == pa.string():
    28+
    elif col_type == pa.string() or col_type == pa.string_view():
    2929
    return F.trim(col(col_name)).alias(col_name)
    3030
    else:
    3131
    return col(col_name)
    @@ -43,7 +43,7 @@ def load_schema(col_name, col_type):
    4343
    def expected_selection(col_name, col_type):
    4444
    if col_type == pa.int64() or col_type == pa.int32():
    4545
    return F.trim(col(col_name)).cast(col_type).alias(col_name)
    46-
    elif col_type == pa.string():
    46+
    elif col_type == pa.string() or col_type == pa.string_view():
    4747
    return F.trim(col(col_name)).alias(col_name)
    4848
    else:
    4949
    return col(col_name)

    python/datafusion/expr.py

    Lines changed: 2 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -51,7 +51,6 @@
    5151
    Column = expr_internal.Column
    5252
    CreateMemoryTable = expr_internal.CreateMemoryTable
    5353
    CreateView = expr_internal.CreateView
    54-
    CrossJoin = expr_internal.CrossJoin
    5554
    Distinct = expr_internal.Distinct
    5655
    DropTable = expr_internal.DropTable
    5756
    EmptyRelation = expr_internal.EmptyRelation
    @@ -140,7 +139,6 @@
    140139
    "Join",
    141140
    "JoinType",
    142141
    "JoinConstraint",
    143-
    "CrossJoin",
    144142
    "Union",
    145143
    "Unnest",
    146144
    "UnnestExpr",
    @@ -376,6 +374,8 @@ def literal(value: Any) -> Expr:
    376374
    377375
    ``value`` must be a valid PyArrow scalar value or easily castable to one.
    378376
    """
    377+
    if isinstance(value, str):
    378+
    value = pa.scalar(value, type=pa.string_view())
    379379
    if not isinstance(value, pa.Scalar):
    380380
    value = pa.scalar(value)
    381381
    return Expr(expr_internal.Expr.literal(value))

    python/datafusion/functions.py

    Lines changed: 8 additions & 3 deletions
    Original file line numberDiff line numberDiff line change
    @@ -297,7 +297,7 @@ def decode(input: Expr, encoding: Expr) -> Expr:
    297297

    298298
    def array_to_string(expr: Expr, delimiter: Expr) -> Expr:
    299299
    """Converts each element to its text representation."""
    300-
    return Expr(f.array_to_string(expr.expr, delimiter.expr))
    300+
    return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string())))
    301301

    302302

    303303
    def array_join(expr: Expr, delimiter: Expr) -> Expr:
    @@ -1067,7 +1067,10 @@ def struct(*args: Expr) -> Expr:
    10671067

    10681068
    def named_struct(name_pairs: list[tuple[str, Expr]]) -> Expr:
    10691069
    """Returns a struct with the given names and arguments pairs."""
    1070-
    name_pair_exprs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs]
    1070+
    name_pair_exprs = [
    1071+
    [Expr.literal(pa.scalar(pair[0], type=pa.string())), pair[1]]
    1072+
    for pair in name_pairs
    1073+
    ]
    10711074

    10721075
    # flatten
    10731076
    name_pairs = [x.expr for xs in name_pair_exprs for x in xs]
    @@ -1424,7 +1427,9 @@ def array_sort(array: Expr, descending: bool = False, null_first: bool = False)
    14241427
    nulls_first = "NULLS FIRST" if null_first else "NULLS LAST"
    14251428
    return Expr(
    14261429
    f.array_sort(
    1427-
    array.expr, Expr.literal(desc).expr, Expr.literal(nulls_first).expr
    1430+
    array.expr,
    1431+
    Expr.literal(pa.scalar(desc, type=pa.string())).expr,
    1432+
    Expr.literal(pa.scalar(nulls_first, type=pa.string())).expr,
    14281433
    )
    14291434
    )
    14301435

    python/datafusion/udf.py

    Lines changed: 1 addition & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -229,6 +229,7 @@ def udaf(
    229229
    which this UDAF is used. The following examples are all valid.
    230230
    231231
    .. code-block:: python
    232+
    232233
    import pyarrow as pa
    233234
    import pyarrow.compute as pc
    234235

    python/tests/test_expr.py

    Lines changed: 12 additions & 4 deletions
    Original file line numberDiff line numberDiff line change
    @@ -85,14 +85,18 @@ def test_limit(test_ctx):
    8585

    8686
    plan = plan.to_variant()
    8787
    assert isinstance(plan, Limit)
    88-
    assert plan.skip() == 0
    88+
    # TODO: Upstream now has expressions for skip and fetch
    89+
    # REF: https://github.com/apache/datafusion/pull/12836
    90+
    # assert plan.skip() == 0
    8991

    9092
    df = test_ctx.sql("select c1 from test LIMIT 10 OFFSET 5")
    9193
    plan = df.logical_plan()
    9294

    9395
    plan = plan.to_variant()
    9496
    assert isinstance(plan, Limit)
    95-
    assert plan.skip() == 5
    97+
    # TODO: Upstream now has expressions for skip and fetch
    98+
    # REF: https://github.com/apache/datafusion/pull/12836
    99+
    # assert plan.skip() == 5
    96100

    97101

    98102
    def test_aggregate_query(test_ctx):
    @@ -126,7 +130,10 @@ def test_relational_expr(test_ctx):
    126130
    ctx = SessionContext()
    127131

    128132
    batch = pa.RecordBatch.from_arrays(
    129-
    [pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
    133+
    [
    134+
    pa.array([1, 2, 3]),
    135+
    pa.array(["alpha", "beta", "gamma"], type=pa.string_view()),
    136+
    ],
    130137
    names=["a", "b"],
    131138
    )
    132139
    df = ctx.create_dataframe([[batch]], name="batch_array")
    @@ -141,7 +148,8 @@ def test_relational_expr(test_ctx):
    141148
    assert df.filter(col("b") == "beta").count() == 1
    142149
    assert df.filter(col("b") != "beta").count() == 2
    143150

    144-
    assert df.filter(col("a") == "beta").count() == 0
    151+
    with pytest.raises(Exception):
    152+
    df.filter(col("a") == "beta").count()
    145153

    146154

    147155
    def test_expr_to_variant():

    python/tests/test_functions.py

    Lines changed: 47 additions & 20 deletions
    +
    (
    Original file line numberDiff line numberDiff line change
    @@ -34,9 +34,9 @@ def df():
    3434
    # create a RecordBatch and a new DataFrame from it
    3535
    batch = pa.RecordBatch.from_arrays(
    3636
    [
    37-
    pa.array(["Hello", "World", "!"]),
    37+
    pa.array(["Hello", "World", "!"], type=pa.string_view()),
    3838
    pa.array([4, 5, 6]),
    39-
    pa.array(["hello ", " world ", " !"]),
    39+
    pa.array(["hello ", " world ", " !"], type=pa.string_view()),
    4040
    pa.array(
    4141
    [
    4242
    datetime(2022, 12, 31),
    @@ -88,16 +88,18 @@ def test_literal(df):
    8888
    assert len(result) == 1
    8989
    result = result[0]
    9090
    assert result.column(0) == pa.array([1] * 3)
    91-
    assert result.column(1) == pa.array(["1"] * 3)
    92-
    assert result.column(2) == pa.array(["OK"] * 3)
    91+
    assert result.column(1) == pa.array(["1"] * 3, type=pa.string_view())
    92+
    assert result.column(2) == pa.array(["OK"] * 3, type=pa.string_view())
    9393
    assert result.column(3) == pa.array([3.14] * 3)
    9494
    assert result.column(4) == pa.array([True] * 3)
    9595
    assert result.column(5) == pa.array([b"hello world"] * 3)
    9696

    9797

    9898
    def test_lit_arith(df):
    9999
    """Test literals with arithmetic operations"""
    100-
    df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!")))
    100+
    df = df.select(
    101+
    literal(1) + column("b"), f.concat(column("a").cast(pa.string()), literal("!"))
    102+
    )
    101103
    result = df.collect()
    102104
    assert len(result) == 1
    103105
    result = result[0]
    @@ -600,21 +602,33 @@ def test_array_function_obj_tests(stmt, py_expr):
    600602
    f.ascii(column("a")),
    601603
    pa.array([72, 87, 33], type=pa.int32()),
    602604
    ), # H = 72; W = 87; ! = 33
    603-
    (f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())),
    604-
    (f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
    605+
    (
    606+
    f.bit_length(column("a").cast(pa.string())),
    607+
    pa.array([40, 40, 8], type=pa.int32()),
    608+
    ),
    609+
    (
    610+
    f.btrim(literal(" World ")),
    611+
    pa.array(["World", "World", "World"], type=pa.string_view()),
    612+
    ),
    605613
    (f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
    606614
    (f.chr(literal(68)), pa.array(["D", "D", "D"])),
    607615
    (
    608616
    f.concat_ws("-", column("a"), literal("test")),
    609617
    pa.array(["Hello-test", "World-test", "!-test"]),
    610618
    ),
    611-
    (f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])),
    619
    620+
    f.concat(column("a").cast(pa.string()), literal("?")),
    621+
    pa.array(["Hello?", "World?", "!?"]),
    622+
    ),
    612623
    (f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])),
    613624
    (f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
    614625
    (f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
    615626
    (f.lower(column("a")), pa.array(["hello", "world", "!"])),
    616627
    (f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
    617-
    (f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])),
    628+
    (
    629+
    f.ltrim(column("c")),
    630+
    pa.array(["hello ", "world ", "!"], type=pa.string_view()),
    631+
    ),
    618632
    (
    619633
    f.md5(column("a")),
    620634
    pa.array(
    @@ -640,19 +654,25 @@ def test_array_function_obj_tests(stmt, py_expr):
    640654
    f.rpad(column("a"), literal(8)),
    641655
    pa.array(["Hello ", "World ", "! "]),
    642656
    ),
    643-
    (f.rtrim(column("c")), pa.array(["hello", " world", " !"])),
    657+
    (
    658+
    f.rtrim(column("c")),
    659+
    pa.array(["hello", " world", " !"], type=pa.string_view()),
    660+
    ),
    644661
    (
    645662
    f.split_part(column("a"), literal("l"), literal(1)),
    646663
    pa.array(["He", "Wor", "!"]),
    647664
    ),
    648665
    (f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
    649666
    (f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
    650-
    (f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])),
    667+
    (
    668+
    f.substr(column("a"), literal(3)),
    669+
    pa.array(["llo", "rld", ""], type=pa.string_view()),
    670+
    ),
    651671
    (
    652672
    f.translate(column("a"), literal("or"), literal("ld")),
    653673
    pa.array(["Helll", "Wldld", "!"]),
    654674
    ),
    655-
    (f.trim(column("c")), pa.array(["hello", "world", "!"])),
    675+
    (f.trim(column("c")), pa.array(["hello", "world", "!"], type=pa.string_view())),
    656676
    (f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
    657677
    (f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
    658678
    (
    @@ -794,9 +814,9 @@ def test_temporal_functions(df):
    794814
    f.date_trunc(literal("month"), column("d")),
    795815
    f.datetrunc(literal("day"), column("d")),
    796816
    f.date_bin(
    797-
    literal("15 minutes"),
    817+
    literal("15 minutes").cast(pa.string()),
    798818
    column("d"),
    799-
    literal("2001-01-01 00:02:30"),
    819+
    literal("2001-01-01 00:02:30").cast(pa.string()),
    800820
    ),
    801821
    f.from_unixtime(literal(1673383974)),
    802822
    f.to_timestamp(literal("2023-09-07 05:06:14.523952")),
    @@ -858,8 +878,8 @@ def test_case(df):
    858878
    result = df.collect()
    859879
    result = result[0]
    860880
    assert result.column(0) == pa.array([10, 8, 8])
    861-
    assert result.column(1) == pa.array(["Hola", "Mundo", "!!"])
    862-
    assert result.column(2) == pa.array(["Hola", "Mundo", None])
    881+
    assert result.column(1) == pa.array(["Hola", "Mundo", "!!"], type=pa.string_view())
    882+
    assert result.column(2) == pa.array(["Hola", "Mundo", None], type=pa.string_view())
    863883

    864884

    865885
    def test_when_with_no_base(df):
    @@ -877,8 +897,10 @@ def test_when_with_no_base(df):
    877897
    result = df.collect()
    878898
    result = result[0]
    879899
    assert result.column(0) == pa.array([4, 5, 6])
    880-
    assert result.column(1) == pa.array(["too small", "just right", "too big"])
    881-
    assert result.column(2) == pa.array(["Hello", None, None])
    900+
    assert result.column(1) == pa.array(
    901+
    ["too small", "just right", "too big"], type=pa.string_view()
    902+
    )
    903+
    assert result.column(2) == pa.array(["Hello", None, None], type=pa.string_view())
    882904

    883905

    884906
    def test_regr_funcs_sql(df):
    @@ -1021,8 +1043,13 @@ def test_regr_funcs_df(func, expected):
    10211043

    10221044
    def test_binary_string_functions(df):
    10231045
    df = df.select(
    1024-
    f.encode(column("a"), literal("base64")),
    1025-
    f.decode(f.encode(column("a"), literal("base64")), literal("base64")),
    1046+
    f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())),
    1047+
    f.decode(
    1048+
    f.encode(
    1049+
    column("a").cast(pa.string()), literal("base64").cast(pa.string())
    1050+
    ),
    1051+
    literal("base64").cast(pa.string()),
    1052+
    ),
    10261053
    )
    10271054
    result = df.collect()
    10281055
    assert len(result) == 1

    python/tests/test_imports.py

    Lines changed: 0 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -46,7 +46,6 @@
    4646
    Join,
    4747
    JoinType,
    4848
    JoinConstraint,
    49-
    CrossJoin,
    5049
    Union,
    5150
    Like,
    5251
    ILike,
    @@ -129,7 +128,6 @@ def test_class_module_is_datafusion():
    129128
    Join,
    130129
    JoinType,
    131130
    JoinConstraint,
    132-
    CrossJoin,
    133131
    Union,
    134132
    Like,
    135133
    ILike,

    python/tests/test_sql.py

    Lines changed: 7 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -468,6 +468,13 @@ def test_simple_select(ctx, tmp_path, arr):
    468468
    batches = ctx.sql("SELECT a AS tt FROM t").collect()
    469469
    result = batches[0].column(0)
    470470

    471+
    # In DF 43.0.0 we now default to having BinaryView and StringView
    472+
    # so the array that is saved to the parquet is slightly different
    473+
    # than the array read. Convert to values for comparison.
    474+
    if isinstance(result, pa.BinaryViewArray) or isinstance(result, pa.StringViewArray):
    475+
    arr = arr.tolist()
    476+
    result = result.tolist()
    477+
    471478
    np.testing.assert_equal(result, arr)
    472479

    473480

    src/context.rs

    Lines changed: 1 addition & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -287,7 +287,7 @@ impl PySessionContext {
    287287
    } else {
    288288
    RuntimeConfig::default()
    289289
    };
    290-
    let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
    290+
    let runtime = Arc::new(RuntimeEnv::try_new(runtime_config)?);
    291291
    let session_state = SessionStateBuilder::new()
    292292
    .with_config(config)
    293293
    .with_runtime_env(runtime)

    src/dataframe.rs

    Lines changed: 6 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -402,7 +402,9 @@ impl PyDataFrame {
    402402

    403403
    #[pyo3(signature = (column, preserve_nulls=true))]
    404404
    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyResult<Self> {
    405-
    let unnest_options = UnnestOptions { preserve_nulls };
    405+
    // TODO: expose RecursionUnnestOptions
    406+
    // REF: https://github.com/apache/datafusion/pull/11577
    407+
    let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
    406408
    let df = self
    407409
    .df
    408410
    .as_ref()
    @@ -413,7 +415,9 @@ impl PyDataFrame {
    413415

    414416
    #[pyo3(signature = (columns, preserve_nulls=true))]
    415417
    fn unnest_columns(&self, columns: Vec<String>, preserve_nulls: bool) -> PyResult<Self> {
    416-
    let unnest_options = UnnestOptions { preserve_nulls };
    418+
    // TODO: expose RecursionUnnestOptions
    419+
    // REF: https://github.com/apache/datafusion/pull/11577
    420+
    let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
    417421
    let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
    418422
    let df = self
    419423
    .df

    src/expr.rs

    Lines changed: 0 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -65,7 +65,6 @@ pub mod column;
    6565
    pub mod conditional_expr;
    6666
    pub mod create_memory_table;
    6767
    pub mod create_view;
    68-
    pub mod cross_join;
    6968
    pub mod distinct;
    7069
    pub mod drop_table;
    7170
    pub mod empty_relation;
    @@ -775,7 +774,6 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    775774
    m.add_class::<join::PyJoin>()?;
    776775
    m.add_class::<join::PyJoinType>()?;
    777776
    m.add_class::<join::PyJoinConstraint>()?;
    778-
    m.add_class::<cross_join::PyCrossJoin>()?;
    779777
    m.add_class::<union::PyUnion>()?;
    780778
    m.add_class::<unnest::PyUnnest>()?;
    781779
    m.add_class::<unnest_expr::PyUnnestExpr>()?;

    0 commit comments

    Comments
     (0)
    0