feat: expose `join_on` (#914) · kylebarron/datafusion-python@b4b03fe · GitHub
[go: up one dir, main page]

Skip to content

Commit b4b03fe

Browse files
authored
feat: expose join_on (apache#914)
* feat: expose join_on method * test: improve join_on case
1 parent 72f2743 commit b4b03fe

File tree

3 files changed

+85
-1
lines changed

3 files changed

+85
-1
lines changed

python/datafusion/dataframe.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from __future__ import annotations
2323

24-
from typing import Any, List, TYPE_CHECKING
24+
from typing import Any, List, TYPE_CHECKING, Literal
2525
from datafusion.record_batch import RecordBatchStream
2626
from typing_extensions import deprecated
2727
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -304,6 +304,29 @@ def join(
304304
"""
305305
return DataFrame(self.df.join(right.df, join_keys, how))
306306

307+
def join_on(
    self,
    right: DataFrame,
    *on_exprs: Expr,
    how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
) -> DataFrame:
    """Join two :py:class:`DataFrame` objects using the given expressions.

    Join expressions make it possible to express inequality predicates.
    Equality predicates are still optimized correctly.

    Args:
        right: Other DataFrame to join with.
        on_exprs: One or more (in)equality predicates.
        how: Type of join to perform. Supported types are "inner", "left",
            "right", "full", "semi", "anti".

    Returns:
        DataFrame after join.
    """
    # The underlying dataframe works on the raw expression held in each
    # wrapper's ``expr`` attribute, so unwrap every predicate first.
    raw_predicates = []
    for predicate in on_exprs:
        raw_predicates.append(predicate.expr)

    joined = self.df.join_on(right.df, raw_predicates, how)
    return DataFrame(joined)
329+
307330
def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame:
308331
"""Return a DataFrame with the explanation of its plan so far.
309332

python/tests/test_dataframe.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,42 @@ def test_join():
270270
assert table.to_pydict() == expected
271271

272272

273+
def test_join_on():
    """Check ``DataFrame.join_on`` with one and with several predicates."""
    ctx = SessionContext()

    # Left side, registered as "l".
    left_batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    df = ctx.create_dataframe([[left_batch]], "l")

    # Right side, registered as "r".
    right_batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([-8, 10])],
        names=["a", "c"],
    )
    df1 = ctx.create_dataframe([[right_batch]], "r")

    # Single equality predicate, written with the comparison operator
    # rather than a direct dunder call.
    df2 = df.join_on(df1, column("l.a") == column("r.a"), how="inner")
    df2.show()
    df2 = df2.sort(column("l.a"))
    table = pa.Table.from_batches(df2.collect())

    expected = {"a": [1, 2], "c": [-8, 10], "b": [4, 5]}
    assert table.to_pydict() == expected

    # Equality plus an inequality predicate: only the row with a == 2
    # satisfies l.a < r.c (2 < 10), since 1 < -8 is false.
    df3 = df.join_on(
        df1,
        column("l.a") == column("r.a"),
        column("l.a") < column("r.c"),
        how="inner",
    )
    df3.show()
    df3 = df3.sort(column("l.a"))
    table = pa.Table.from_batches(df3.collect())
    expected = {"a": [2], "c": [10], "b": [5]}
    assert table.to_pydict() == expected
307+
308+
273309
def test_distinct():
274310
ctx = SessionContext()
275311

src/dataframe.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,31 @@ impl PyDataFrame {
300300
Ok(Self::new(df))
301301
}
302302

303+
fn join_on(&self, right: PyDataFrame, on_exprs: Vec<PyExpr>, how: &str) -> PyResult<Self> {
304+
let join_type = match how {
305+
"inner" => JoinType::Inner,
306+
"left" => JoinType::Left,
307+
"right" => JoinType::Right,
308+
"full" => JoinType::Full,
309+
"semi" => JoinType::LeftSemi,
310+
"anti" => JoinType::LeftAnti,
311+
how => {
312+
return Err(DataFusionError::Common(format!(
313+
"The join type {how} does not exist or is not implemented"
314+
))
315+
.into());
316+
}
317+
};
318+
let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
319+
320+
let df = self
321+
.df
322+
.as_ref()
323+
.clone()
324+
.join_on(right.df.as_ref().clone(), join_type, exprs)?;
325+
Ok(Self::new(df))
326+
}
327+
303328
/// Print the query plan
304329
#[pyo3(signature = (verbose=false, analyze=false))]
305330
fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> {

0 commit comments

Comments
 (0)
0