8000 Add MergeQueryFiles config opt · sqlc-dev/sqlc-gen-python@b821701 · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit b821701

Browse files
committed
Add MergeQueryFiles config opt
1 parent 7b4b355 commit b821701

File tree

3 files changed

+13
-4
lines changed
Collapse file tree

3 files changed

+13
-4
lines changed

internal/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ type Config struct {
1111
// When a query uses a table with RLS enforced fields, it will be required to
1212
// parametrized those fields. Associate tables are not covered!
1313
RLSEnforcedFields []string `json:"rls_enforced_fields"`
14+
// Merge queries defined in different files into one output queries.py file
15+
MergeQueryFiles bool `json:"merge_query_files"`
1416
}
1517

1618
const MODELS_FILENAME = "db_models"

internal/gen.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,9 @@ type pyTmplCtx struct {
11131113
}
11141114

11151115
func (t *pyTmplCtx) OutputQuery(sourceName string) bool {
1116+
if t.C.MergeQueryFiles {
1117+
return true
1118+
}
11161119
return t.SourceName == sourceName
11171120
}
11181121

@@ -1156,8 +1159,12 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR
11561159
output[MODELS_FILENAME+".py"] = string(result.Python)
11571160

11581161
files := map[string]struct{}{}
1159-
for _, q := range queries {
1160-
files[q.SourceName] = struct{}{}
1162+
if i.C.MergeQueryFiles {
1163+
files["db_queries.sql"] = struct{}{}
1164+
} else {
1165+
for _, q := range queries {
1166+
files[q.SourceName] = struct{}{}
1167+
}
11611168
}
11621169

11631170
for source := range files {

internal/imports.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func (i *importer) modelImports() []string {
113113
func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map[string]importSpec) {
114114
queryUses := func(name string) bool {
115115
for _, q := range i.Queries {
116-
if q.SourceName != fileName {
116+
if !i.C.MergeQueryFiles && q.SourceName != fileName {
117117
continue
118118
}
119119
if queryValueUses(name, q.Ret) {
@@ -144,7 +144,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map
144144
}
145145

146146
for _, q := range i.Queries {
147-
if q.SourceName != fileName {
147+
if !i.C.MergeQueryFiles && q.SourceName != fileName {
148148
continue
149149
}
150150
if q.Cmd == ":one" {

0 commit comments

Comments
 (0)
0