From 41d4d6ba4ae36239b3892041625b4272cff3cac6 Mon Sep 17 00:00:00 2001 From: simo7 Date: Sat, 14 Sep 2024 01:43:00 +0200 Subject: [PATCH] Refactor w/ isAlwaysReturningInsert() --- internal/gen.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/internal/gen.go b/internal/gen.go index bd21d83..a5d36c9 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -889,8 +889,8 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { case ":one": fetchrow := connMethodNode("fetchrow", q.ConstantName, q.ArgNodes()...) f.Body = append(f.Body, assignNode("row", poet.Await(fetchrow))) - if !strings.Contains(strings.ToUpper(q.SQL), "WHERE ") && - (strings.HasPrefix(q.ConstantName, "INSERT") || strings.HasPrefix(q.ConstantName, "UPSERT")) { + + if isAlwaysReturningInsert(q.SQL) { f.Returns = q.Ret.Annotation() } else { f.Body = append(f.Body, poet.Node( @@ -1028,3 +1028,18 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR return &resp, nil } + +func isAlwaysReturningInsert(sql string) bool { + var hasInsert, hasWhere, hasReturning bool + for _, w := range strings.Fields(sql) { + switch strings.ToUpper(w) { + case "INSERT": + hasInsert = true + case "WHERE": + hasWhere = true + case "RETURNING": + hasReturning = true + } + } + return hasInsert && hasReturning && !hasWhere +}