@@ -151,6 +151,26 @@ func (q Query) AddArgs(args *pyast.Arguments) {
151
151
}
152
152
}
153
153
154
+ func (q Query ) ArgNodes () []* pyast.Node {
155
+ args := []* pyast.Node {}
156
+ i := 1
157
+ for _ , a := range q .Args {
158
+ if a .isEmpty () {
159
+ continue
160
+ }
161
+ if a .IsStruct () {
162
+ for _ , f := range a .Struct .Fields {
163
+ args = append (args , typeRefNode (a .Name , f .Name ))
164
+ i ++
165
+ }
166
+ } else {
167
+ args = append (args , poet .Name (a .Name ))
168
+ i ++
169
+ }
170
+ }
171
+ return args
172
+ }
173
+
154
174
func (q Query ) ArgDictNode () * pyast.Node {
155
175
dict := & pyast.Dict {}
156
176
i := 1
@@ -612,11 +632,9 @@ func typeRefNode(base string, parts ...string) *pyast.Node {
612
632
return n
613
633
}
614
634
615
- func connMethodNode (method , name string , arg * pyast.Node ) * pyast.Node {
635
+ func connMethodNode (method , name string , params ... * pyast.Node ) * pyast.Node {
616
636
args := []* pyast.Node {poet .Name (name )}
617
- if arg != nil {
618
- args = append (args , arg )
619
- }
637
+ args = append (args , params ... )
620
638
return & pyast.Node {
621
639
Node : & pyast.Node_Call {
622
640
Call : & pyast.Call {
@@ -869,7 +887,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
869
887
870
888
switch q .Cmd {
871
889
case ":one" :
872
- fetchrow := connMethodNode ("fetchrow" , q .ConstantName , q .ArgDictNode () )
890
+ fetchrow := connMethodNode ("fetchrow" , q .ConstantName , q .ArgNodes () ... )
873
891
f .Body = append (f .Body ,
874
892
assignNode ("row" , poet .Await (fetchrow )),
875
893
poet .Node (
@@ -896,7 +914,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
896
914
)
897
915
f .Returns = subscriptNode ("Optional" , q .Ret .Annotation ())
898
916
case ":many" :
899
- cursor := connMethodNode ("cursor" , q .ConstantName , q .ArgDictNode () )
917
+ cursor := connMethodNode ("cursor" , q .ConstantName , q .ArgNodes () ... )
900
918
f .Body = append (f .Body ,
901
919
poet .Node (
902
920
& pyast.AsyncFor {
@@ -914,7 +932,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
914
932
)
915
933
f .Returns = subscriptNode ("AsyncIterator" , q .Ret .Annotation ())
916
934
case ":exec" :
917
- exec := connMethodNode ("execute" , q .ConstantName , q .ArgDictNode () )
935
+ exec := connMethodNode ("execute" , q .ConstantName , q .ArgNodes () ... )
918
936
f .Body = append (f .Body , poet .Await (exec ))
919
937
f .Returns = poet .Constant (nil )
920
938
default :
0 commit comments