16
16
// under the License.
17
17
18
18
use datafusion_common:: DataFusionError ;
19
+ use datafusion_expr:: expr:: { AggregateFunction , AggregateUDF , Alias } ;
19
20
use datafusion_expr:: logical_plan:: Aggregate ;
21
+ use datafusion_expr:: Expr ;
20
22
use pyo3:: prelude:: * ;
21
23
use std:: fmt:: { self , Display , Formatter } ;
22
24
23
25
use super :: logical_node:: LogicalNode ;
24
26
use crate :: common:: df_schema:: PyDFSchema ;
27
+ use crate :: errors:: py_type_err;
25
28
use crate :: expr:: PyExpr ;
26
29
use crate :: sql:: logical:: PyLogicalPlan ;
27
30
@@ -84,6 +87,24 @@ impl PyAggregate {
84
87
. collect ( ) )
85
88
}
86
89
90
+ /// Returns the inner Aggregate Expr(s)
91
+ pub fn agg_expressions ( & self ) -> PyResult < Vec < PyExpr > > {
92
+ Ok ( self
93
+ . aggregate
94
+ . aggr_expr
95
+ . iter ( )
96
+ . map ( |e| PyExpr :: from ( e. clone ( ) ) )
97
+ . collect ( ) )
98
+ }
99
+
100
+ pub fn agg_func_name ( & self , expr : PyExpr ) -> PyResult < String > {
101
+ Self :: _agg_func_name ( & expr. expr )
102
+ }
103
+
104
+ pub fn aggregation_arguments ( & self , expr : PyExpr ) -> PyResult < Vec < PyExpr > > {
105
+ self . _aggregation_arguments ( & expr. expr )
106
+ }
107
+
87
108
// Retrieves the input `LogicalPlan` to this `Aggregate` node
88
109
fn input ( & self ) -> PyResult < Vec < PyLogicalPlan > > {
89
110
Ok ( Self :: inputs ( self ) )
@@ -99,6 +120,34 @@ impl PyAggregate {
99
120
}
100
121
}
101
122
123
+ impl PyAggregate {
124
+ #[ allow( clippy:: only_used_in_recursion) ]
125
+ fn _aggregation_arguments ( & self , expr : & Expr ) -> PyResult < Vec < PyExpr > > {
126
+ match expr {
127
+ // TODO: This Alias logic seems to be returning some strange results that we should investigate
128
+ Expr :: Alias ( Alias { expr, .. } ) => self . _aggregation_arguments ( expr. as_ref ( ) ) ,
129
+ Expr :: AggregateFunction ( AggregateFunction { fun : _, args, .. } )
130
+ | Expr :: AggregateUDF ( AggregateUDF { fun : _, args, .. } ) => {
131
+ Ok ( args. iter ( ) . map ( |e| PyExpr :: from ( e. clone ( ) ) ) . collect ( ) )
132
+ }
133
+ _ => Err ( py_type_err (
134
+ "Encountered a non Aggregate type in aggregation_arguments" ,
135
+ ) ) ,
136
+ }
137
+ }
138
+
139
+ fn _agg_func_name ( expr : & Expr ) -> PyResult < String > {
140
+ match expr {
141
+ Expr :: Alias ( Alias { expr, .. } ) => Self :: _agg_func_name ( expr. as_ref ( ) ) ,
142
+ Expr :: AggregateFunction ( AggregateFunction { fun, .. } ) => Ok ( fun. to_string ( ) ) ,
143
+ Expr :: AggregateUDF ( AggregateUDF { fun, .. } ) => Ok ( fun. name . clone ( ) ) ,
144
+ _ => Err ( py_type_err (
145
+ "Encountered a non Aggregate type in agg_func_name" ,
146
+ ) ) ,
147
+ }
148
+ }
149
+ }
150
+
102
151
impl LogicalNode for PyAggregate {
103
152
fn inputs ( & self ) -> Vec < PyLogicalPlan > {
104
153
vec ! [ PyLogicalPlan :: from( ( * self . aggregate. input) . clone( ) ) ]
0 commit comments