Expose unnest feature (#641) · 3ok/datafusion-python@7366f89 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7366f89

Browse files
authored
Expose unnest feature (apache#641)
* Expose unnest feature * Update dataframe operation name to match rust implementation
1 parent 84415dd commit 7366f89

File tree

5 files changed

+135
-0
lines changed

5 files changed

+135
-0
lines changed

datafusion/tests/test_dataframe.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,20 @@ def struct_df():
6262
return ctx.create_dataframe([[batch]])
6363

6464

65+
@pytest.fixture
def nested_df():
    """Provide a two-column DataFrame whose column "a" is list-typed.

    Column "a" holds lists of one, two and three integers followed by a
    null entry; column "b" holds one plain integer per row.
    """
    session = SessionContext()

    # The lists in "a" deliberately differ in length, and the last row is
    # null rather than a list.
    list_values = pa.array([[1], [2, 3], [4, 5, 6], None])
    scalar_values = pa.array([7, 8, 9, 10])
    record_batch = pa.RecordBatch.from_arrays(
        [list_values, scalar_values],
        names=["a", "b"],
    )

    # A single partition containing a single batch.
    return session.create_dataframe([[record_batch]])
77+
78+
6579
@pytest.fixture
6680
def aggregate_df():
6781
ctx = SessionContext()
@@ -160,6 +174,26 @@ def test_with_column_renamed(df):
160174
assert result.schema.field(2).name == "sum"
161175

162176

177+
def test_unnest(nested_df):
    """Unnest column "a" with default options and check both columns."""
    unnested = nested_df.unnest_column("a")

    # Only the first collected batch is inspected.
    batch = unnested.collect()[0]

    # Each list element becomes its own row; the null list yields a null row.
    expected_a = pa.array([1, 2, 3, 4, 5, 6, None])
    # The value of "b" is repeated once per element produced from its row.
    expected_b = pa.array([7, 8, 8, 9, 9, 9, 10])

    assert batch.column(0) == expected_a
    assert batch.column(1) == expected_b
185+
186+
187+
def test_unnest_without_nulls(nested_df):
    """Unnest column "a" with ``preserve_nulls=False`` and check both columns."""
    unnested = nested_df.unnest_column("a", preserve_nulls=False)

    # Only the first collected batch is inspected.
    batch = unnested.collect()[0]

    # The row whose list was null does not appear in the output.
    expected_a = pa.array([1, 2, 3, 4, 5, 6])
    expected_b = pa.array([7, 8, 8, 9, 9, 9])

    assert batch.column(0) == expected_a
    assert batch.column(1) == expected_b
195+
196+
163197
def test_udf(df):
164198
# is_null is a pa function over arrays
165199
is_null = udf(

src/dataframe.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use datafusion::execution::SendableRecordBatchStream;
2525
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
2626
use datafusion::parquet::file::properties::WriterProperties;
2727
use datafusion::prelude::*;
28+
use datafusion_common::UnnestOptions;
2829
use pyo3::exceptions::{PyTypeError, PyValueError};
2930
use pyo3::prelude::*;
3031
use pyo3::types::PyTuple;
@@ -293,6 +294,17 @@ impl PyDataFrame {
293294
Ok(Self::new(new_df))
294295
}
295296

297+
#[pyo3(signature = (column, preserve_nulls=true))]
298+
fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyResult<Self> {
299+
let unnest_options = UnnestOptions { preserve_nulls };
300+
let df = self
301+
.df
302+
.as_ref()
303+
.clone()
304+
.unnest_column_with_options(column, unnest_options)?;
305+
Ok(Self::new(df))
306+
}
307+
296308
/// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
297309
fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> {
298310
let new_df = self

src/expr.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ pub mod subquery;
8989
pub mod subquery_alias;
9090
pub mod table_scan;
9191
pub mod union;
92+
pub mod unnest;
9293
pub mod window;
9394

9495
/// A PyExpr that can be used on a DataFrame
@@ -684,6 +685,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
684685
m.add_class::<join::PyJoinConstraint>()?;
685686
m.add_class::<cross_join::PyCrossJoin>()?;
686687
m.add_class::<union::PyUnion>()?;
688+
m.add_class::<unnest::PyUnnest>()?;
687689
m.add_class::<extension::PyExtension>()?;
688690
m.add_class::<filter::PyFilter>()?;
689691
m.add_class::<projection::PyProjection>()?;

src/expr/unnest.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion_expr::logical_plan::Unnest;
19+
use pyo3::prelude::*;
20+
use std::fmt::{self, Display, Formatter};
21+
22+
use crate::common::df_schema::PyDFSchema;
23+
use crate::expr::logical_node::LogicalNode;
24+
use crate::sql::logical::PyLogicalPlan;
25+
26+
#[pyclass(name = "Unnest", module = "datafusion.expr", subclass)]
27+
#[derive(Clone)]
28+
pub struct PyUnnest {
29+
unnest_: Unnest,
30+
}
31+
32+
impl From<Unnest> for PyUnnest {
33+
fn from(unnest_: Unnest) -> PyUnnest {
34+
PyUnnest { unnest_ }
35+
}
36+
}
37+
38+
impl From<PyUnnest> for Unnest {
39+
fn from(unnest_: PyUnnest) -> Self {
40+
unnest_.unnest_
41+
}
42+
}
43+
44+
impl Display for PyUnnest {
45+
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
46+
write!(
47+
f,
48+
"Unnest
49+
Inputs: {:?}
50+
Schema: {:?}",
51+
&self.unnest_.input, &self.unnest_.schema,
52+
)
53+
}
54+
}
55+
56+
#[pymethods]
57+
impl PyUnnest {
58+
/// Retrieves the input `LogicalPlan` to this `Unnest` node
59+
fn input(&self) -> PyResult<Vec<PyLogicalPlan>> {
60+
Ok(Self::inputs(self))
61+
}
62+
63+
/// Resulting Schema for this `Unnest` node instance
64+
fn schema(&self) -> PyResult<PyDFSchema> {
65+
Ok(self.unnest_.schema.as_ref().clone().into())
66+
}
67+
68+
fn __repr__(&self) -> PyResult<String> {
69+
Ok(format!("Unnest({})", self))
70+
}
71+
72+
fn __name__(&self) -> PyResult<String> {
73+
Ok("Unnest".to_string())
74+
}
75+
}
76+
77+
impl LogicalNode for PyUnnest {
78+
fn inputs(&self) -> Vec<PyLogicalPlan> {
79+
vec![PyLogicalPlan::from((*self.unnest_.input).clone())]
80+
}
81+
82+
fn to_variant(&self, py: Python) -> PyResult<PyObject> {
83+
Ok(self.clone().into_py(py))
84+
}
85+
}

src/sql/logical.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use crate::expr::sort::PySort;
3333
use crate::expr::subquery::PySubquery;
3434
use crate::expr::subquery_alias::PySubqueryAlias;
3535
use crate::expr::table_scan::PyTableScan;
36+
use crate::expr::unnest::PyUnnest;
3637
use crate::expr::window::PyWindow;
3738
use datafusion_expr::LogicalPlan;
3839
use pyo3::prelude::*;
@@ -78,6 +79,7 @@ impl PyLogicalPlan {
7879
LogicalPlan::TableScan(plan) => PyTableScan::from(plan.clone()).to_variant(py),
7980
LogicalPlan::Subquery(plan) => PySubquery::from(plan.clone()).to_variant(py),
8081
LogicalPlan::SubqueryAlias(plan) => PySubqueryAlias::from(plan.clone()).to_variant(py),
82+
LogicalPlan::Unnest(plan) => PyUnnest::from(plan.clone()).to_variant(py),
8183
LogicalPlan::Window(plan) => PyWindow::from(plan.clone()).to_variant(py),
8284
other => Err(py_unsupported_variant_err(format!(
8385
"Cannot convert this plan to a LogicalNode: {:?}",

0 commit comments

Comments
 (0)
0