8000 Refactor w/ isAlwaysReturningInsert() · sqlc-dev/sqlc-gen-python@41d4d6b · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit 41d4d6b

Browse files
committed
Refactor w/ isAlwaysReturningInsert()
1 parent d1a001d commit 41d4d6b

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

internal/gen.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -889,8 +889,8 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
889889
case ":one":
890890
fetchrow := connMethodNode("fetchrow", q.ConstantName, q.ArgNodes()...)
891891
f.Body = append(f.Body, assignNode("row", poet.Await(fetchrow)))
892-
if !strings.Contains(strings.ToUpper(q.SQL), "WHERE ") &&
893-
(strings.HasPrefix(q.ConstantName, "INSERT") || strings.HasPrefix(q.ConstantName, "UPSERT")) {
892+
893+
if isAlwaysReturningInsert(q.SQL) {
894894
f.Returns = q.Ret.Annotation()
895895
} else {
896896
f.Body = append(f.Body, poet.Node(
@@ -1028,3 +1028,18 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR
10281028

10291029
return &resp, nil
10301030
}
1031+
1032+
func isAlwaysReturningInsert(sql string) bool {
1033+
var hasInsert, hasWhere, hasReturning bool
1034+
for _, w := range strings.Fields(sql) {
1035+
switch strings.ToUpper(w) {
1036+
case "INSERT":
1037+
hasInsert = true
1038+
case "WHERE":
1039+
hasWhere = true
1040+
case "RETURNING":
1041+
hasReturning = true
1042+
}
1043+
}
1044+
return hasInsert && hasReturning && !hasWhere
1045+
}

0 commit comments

Comments
 (0)
0