diff --git a/datafusion/tests/test_imports.py b/datafusion/tests/test_imports.py index b7f4ef4ad..5a6c68506 100644 --- a/datafusion/tests/test_imports.py +++ b/datafusion/tests/test_imports.py @@ -35,6 +35,7 @@ Expr, Projection, TableScan, + Aggregate, Sort, ) @@ -56,7 +57,7 @@ def test_class_module_is_datafusion(): ]: assert klass.__module__ == "datafusion" - for klass in [Expr, Projection, TableScan, Sort]: + for klass in [Expr, Projection, TableScan, Aggregate, Sort]: assert klass.__module__ == "datafusion.expr" for klass in [DFField, DFSchema]: diff --git a/src/expr.rs b/src/expr.rs index 68534bcba..15359d400 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -24,6 +24,7 @@ use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField}; use datafusion::scalar::ScalarValue; +pub mod aggregate; pub mod logical_node; pub mod projection; pub mod sort; @@ -144,6 +145,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; Ok(()) } diff --git a/src/expr/aggregate.rs b/src/expr/aggregate.rs new file mode 100644 index 000000000..98d1f554b --- /dev/null +++ b/src/expr/aggregate.rs @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::DataFusionError; +use datafusion_expr::logical_plan::Aggregate; +use pyo3::prelude::*; +use std::fmt::{self, Display, Formatter}; + +use crate::common::df_schema::PyDFSchema; +use crate::expr::logical_node::LogicalNode; +use crate::expr::PyExpr; +use crate::sql::logical::PyLogicalPlan; + +#[pyclass(name = "Aggregate", module = "datafusion.expr", subclass)] +#[derive(Clone)] +pub struct PyAggregate { + aggregate: Aggregate, +} + +impl From for PyAggregate { + fn from(aggregate: Aggregate) -> PyAggregate { + PyAggregate { aggregate } + } +} + +impl TryFrom for Aggregate { + type Error = DataFusionError; + + fn try_from(agg: PyAggregate) -> Result { + Ok(agg.aggregate) + } +} + +impl Display for PyAggregate { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "Aggregate + \nGroupBy(s): {:?} + \nAggregates(s): {:?} + \nInput: {:?} + \nProjected Schema: {:?}", + &self.aggregate.group_expr, + &self.aggregate.aggr_expr, + self.aggregate.input, + self.aggregate.schema + ) + } +} + +#[pymethods] +impl PyAggregate { + /// Retrieves the grouping expressions for this `Aggregate` + fn group_by_exprs(&self) -> PyResult> { + Ok(self + .aggregate + .group_expr + .iter() + .map(|e| PyExpr::from(e.clone())) + .collect()) + } + + /// Retrieves the aggregate expressions for this `Aggregate` + fn aggregate_exprs(&self) -> PyResult> { + Ok(self + .aggregate + .aggr_expr + .iter() + .map(|e| PyExpr::from(e.clone())) + .collect()) + } + + // Retrieves the input `LogicalPlan` to this `Aggregate` node + fn input(&self) -> PyLogicalPlan { + PyLogicalPlan::from((*self.aggregate.input).clone()) + } + + // Resulting Schema for this `Aggregate` node instance + fn schema(&self) -> PyDFSchema { + (*self.aggregate.schema).clone().into() + } + + fn __repr__(&self) -> PyResult { + Ok(format!("Aggregate({})", self)) + } +} + +impl LogicalNode for PyAggregate { + fn input(&self) -> Vec { + vec![PyLogicalPlan::from((*self.aggregate.input).clone())] + } +}