From 245f052f535dc68bb8937cf3924fc31132b83cb0 Mon Sep 17 00:00:00 2001 From: Kyle Gray Date: Mon, 5 Aug 2024 12:55:51 -0700 Subject: [PATCH 1/8] build(deps): Upgrade to sqlc v1.27.0 --- examples/src/authors/models.py | 2 +- examples/src/authors/query.py | 2 +- examples/src/booktest/models.py | 2 +- examples/src/booktest/query.py | 2 +- examples/src/jets/models.py | 2 +- examples/src/jets/query-building.py | 2 +- examples/src/ondeck/city.py | 2 +- examples/src/ondeck/models.py | 2 +- examples/src/ondeck/venue.py | 2 +- internal/endtoend/testdata/emit_pydantic_models/db/models.py | 2 +- internal/endtoend/testdata/emit_pydantic_models/db/query.py | 2 +- internal/endtoend/testdata/exec_result/python/models.py | 2 +- internal/endtoend/testdata/exec_result/python/query.py | 2 +- internal/endtoend/testdata/exec_rows/python/models.py | 2 +- internal/endtoend/testdata/exec_rows/python/query.py | 2 +- .../testdata/inflection_exclude_table_names/python/models.py | 2 +- .../testdata/inflection_exclude_table_names/python/query.py | 2 +- .../testdata/query_parameter_limit_two/python/models.py | 2 +- .../endtoend/testdata/query_parameter_limit_two/python/query.py | 2 +- .../testdata/query_parameter_limit_undefined/python/models.py | 2 +- .../testdata/query_parameter_limit_undefined/python/query.py | 2 +- .../testdata/query_parameter_limit_zero/python/models.py | 2 +- .../testdata/query_parameter_limit_zero/python/query.py | 2 +- 23 files changed, 23 insertions(+), 23 deletions(-) diff --git a/examples/src/authors/models.py b/examples/src/authors/models.py index d3ade78..906e980 100644 --- a/examples/src/authors/models.py +++ b/examples/src/authors/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import dataclasses from typing import Optional diff --git a/examples/src/authors/query.py b/examples/src/authors/query.py index 62ebde3..afa81ab 100644 --- a/examples/src/authors/query.py +++ b/examples/src/authors/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: query.sql from typing import AsyncIterator, Iterator, Optional diff --git a/examples/src/booktest/models.py b/examples/src/booktest/models.py index f6f3d31..d5dacb4 100644 --- a/examples/src/booktest/models.py +++ b/examples/src/booktest/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import dataclasses import datetime import enum diff --git a/examples/src/booktest/query.py b/examples/src/booktest/query.py index a52394a..1720a02 100644 --- a/examples/src/booktest/query.py +++ b/examples/src/booktest/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: query.sql import dataclasses import datetime diff --git a/examples/src/jets/models.py b/examples/src/jets/models.py index 6f42c41..efd4e2b 100644 --- a/examples/src/jets/models.py +++ b/examples/src/jets/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import dataclasses diff --git a/examples/src/jets/query-building.py b/examples/src/jets/query-building.py index 14f0752..374e01e 100644 --- a/examples/src/jets/query-building.py +++ b/examples/src/jets/query-building.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: query-building.sql from typing import AsyncIterator, Optional diff --git a/examples/src/ondeck/city.py b/examples/src/ondeck/city.py index 2fd7acf..b24a5d6 100644 --- a/examples/src/ondeck/city.py +++ b/examples/src/ondeck/city.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: city.sql from typing import AsyncIterator, Optional diff --git a/examples/src/ondeck/models.py b/examples/src/ondeck/models.py index 1babdff..bcc12f1 100644 --- a/examples/src/ondeck/models.py +++ b/examples/src/ondeck/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import dataclasses import datetime import enum diff --git a/examples/src/ondeck/venue.py b/examples/src/ondeck/venue.py index f5bb1f0..f540076 100644 --- a/examples/src/ondeck/venue.py +++ b/examples/src/ondeck/venue.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: venue.sql import dataclasses from typing import AsyncIterator, List, Optional diff --git a/internal/endtoend/testdata/emit_pydantic_models/db/models.py b/internal/endtoend/testdata/emit_pydantic_models/db/models.py index 7b3457e..1e312fd 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/db/models.py +++ b/internal/endtoend/testdata/emit_pydantic_models/db/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import pydantic from typing import Optional diff --git a/internal/endtoend/testdata/emit_pydantic_models/db/query.py b/internal/endtoend/testdata/emit_pydantic_models/db/query.py index 67ce686..7739bd7 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/db/query.py +++ b/internal/endtoend/testdata/emit_pydantic_models/db/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: query.sql from typing import AsyncIterator, Iterator, Optional diff --git a/internal/endtoend/testdata/exec_result/python/models.py b/internal/endtoend/testdata/exec_result/python/models.py index 20c9e31..7a4ffd0 100644 --- a/internal/endtoend/testdata/exec_result/python/models.py +++ b/internal/endtoend/testdata/exec_result/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import dataclasses diff --git a/internal/endtoend/testdata/exec_result/python/query.py b/internal/endtoend/testdata/exec_result/python/query.py index 720ab18..5828377 100644 --- a/internal/endtoend/testdata/exec_result/python/query.py +++ b/internal/endtoend/testdata/exec_result/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/exec_rows/python/models.py b/internal/endtoend/testdata/exec_rows/python/models.py index 20c9e31..7a4ffd0 100644 --- a/internal/endtoend/testdata/exec_rows/python/models.py +++ b/internal/endtoend/testdata/exec_rows/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import dataclasses diff --git a/internal/endtoend/testdata/exec_rows/python/query.py b/internal/endtoend/testdata/exec_rows/python/query.py index d1e117b..05c3094 100644 --- a/internal/endtoend/testdata/exec_rows/python/query.py +++ b/internal/endtoend/testdata/exec_rows/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py b/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py index 3702b2a..1b8c35a 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py +++ b/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import dataclasses diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py b/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py index 9ff587f..e09b677 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py +++ b/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: query.sql from typing import Optional diff --git a/internal/endtoend/testdata/query_parameter_limit_two/python/models.py b/internal/endtoend/testdata/query_parameter_limit_two/python/models.py index 4813c11..edafe46 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_two/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_two/python/query.py b/internal/endtoend/testdata/query_parameter_limit_two/python/query.py index 36cf1ba..48cf42a 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_two/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py b/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py index 794d992..3019fbd 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py b/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py index a787e5d..6c8f593 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py b/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py index 4813c11..edafe46 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py b/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py index 24de8f1..ee36b2a 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.26.0 +# sqlc v1.27.0 # source: query.sql import dataclasses From 4d2625e9b5eec426d8676b9619cb1d6d43e0f762 Mon Sep 17 00:00:00 2001 From: simo7 Date: Fri, 4 Oct 2024 14:12:57 +0200 Subject: [PATCH 2/8] models.py -> db_models.py --- internal/gen.go | 4 ++-- internal/imports.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/gen.go b/internal/gen.go index a5d36c9..2990db5 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -998,8 +998,8 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR output := map[string]string{} result := pyprint.Print(buildModelsTree(&tctx, i), pyprint.Options{}) - tctx.SourceName = "models.py" - output["models.py"] = string(result.Python) + tctx.SourceName = "db_models.py" + output["db_models.py"] = string(result.Python) files := map[string]struct{}{} for _, q := range queries { diff --git a/internal/imports.go b/internal/imports.go index 423b1a0..08bc710 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -69,7 +69,7 @@ func queryValueUses(name string, qv QueryValue) bool { } func (i *importer) Imports(fileName string) []string { - if fileName == "models.py" { + if fileName == "db_models.py" { return i.modelImports() } return i.queryImports(fileName) From 648631bd94c9a9d58ca2f841817fb5466382e7d4 Mon Sep 17 00:00:00 2001 From: simo7 Date: Fri, 4 Oct 2024 18:32:17 +0200 Subject: [PATCH 3/8] revert to keyword-only args --- internal/printer/printer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/printer/printer.go b/internal/printer/printer.go index f56ff45..0660c6a 100644 --- a/internal/printer/printer.go +++ b/internal/printer/printer.go @@ -381,7 +381,7 @@ func (w *writer) printFunctionDef(fd *ast.FunctionDef, indent int32) { } } if len(fd.Args.KwOnlyArgs) > 0 { - w.print(", ") + w.print(", *, ") for i, arg := range fd.Args.KwOnlyArgs { w.printArg(arg, indent) if i != len(fd.Args.KwOnlyArgs)-1 { From abec3c8744d00b1527c3a40134b5df47ae6b5d39 Mon Sep 17 00:00:00 2001 From: simo7 Date: Fri, 4 Oct 2024 18:40:52 +0200 Subject: [PATCH 4/8] Default to None for optional func args --- internal/gen.go | 46 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/internal/gen.go b/internal/gen.go index 2990db5..08a27b3 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -38,12 +38,15 @@ type pyType struct { IsNull bool } -func (t pyType) Annotation() *pyast.Node { +func (t pyType) Annotation(isFuncSignature bool) *pyast.Node { ann := poet.Name(t.InnerType) if t.IsArray { ann = subscriptNode("List", ann) } - if t.IsNull { + if t.IsNull && isFuncSignature { + ann = optionalKeywordNode("Optional", ann) + } + if t.IsNull && !isFuncSignature { ann = subscriptNode("Optional", ann) } return ann @@ -69,9 +72,10 @@ type QueryValue struct { Typ pyType } +// Annotation in function signature func (v QueryValue) Annotation() *pyast.Node { if v.Typ != (pyType{}) { - return v.Typ.Annotation() + return v.Typ.Annotation(true) } if v.Struct != nil { if v.Emit { @@ -143,12 +147,21 @@ func (q Query) AddArgs(args *pyast.Arguments) { }) return } + var optionalArgs []*pyast.Arg for _, a := range q.Args { + if a.Typ.IsNull { + optionalArgs = append(optionalArgs, &pyast.Arg{ + Arg: a.Name, + Annotation: a.Annotation(), + }) + continue + } args.KwOnlyArgs = append(args.KwOnlyArgs, &pyast.Arg{ Arg: a.Name, Annotation: a.Annotation(), }) } + args.KwOnlyArgs = append(args.KwOnlyArgs, optionalArgs...) } func (q Query) ArgNodes() []*pyast.Node { @@ -577,6 +590,31 @@ func subscriptNode(value string, slice *pyast.Node) *pyast.Node { } } +func optionalKeywordNode(value string, slice *pyast.Node) *pyast.Node { + v := &pyast.Node{ + Node: &pyast.Node_Subscript{ + Subscript: &pyast.Subscript{ + Value: &pyast.Name{Id: value}, + Slice: slice, + }, + }, + } + return &pyast.Node{ + Node: &pyast.Node_Keyword{ + Keyword: &pyast.Keyword{ + Arg: string(pyprint.Print(v, pyprint.Options{}).Python), + Value: &pyast.Node{ + Node: &pyast.Node_Constant{ + Constant: &pyast.Constant{ + Value: &pyast.Constant_None{None: true}, + }, + }, + }, + }, + }, + } +} + func dataclassNode(name string) *pyast.ClassDef { return &pyast.ClassDef{ Name: name, @@ -617,7 +655,7 @@ func fieldNode(f Field) *pyast.Node { Node: &pyast.Node_AnnAssign{ AnnAssign: &pyast.AnnAssign{ Target: &pyast.Name{Id: f.Name}, - Annotation: f.Type.Annotation(), + Annotation: f.Type.Annotation(false), Comment: f.Comment, }, }, From b376daadd34c421d904c6274897d0f652661ae8f Mon Sep 17 00:00:00 2001 From: simo7 Date: Sat, 5 Oct 2024 13:22:48 +0200 Subject: [PATCH 5/8] RLS enforced fields --- internal/config.go | 3 +++ internal/gen.go | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/internal/config.go b/internal/config.go index 899009e..032eb01 100644 --- a/internal/config.go +++ b/internal/config.go @@ -8,4 +8,7 @@ type Config struct { QueryParameterLimit *int32 `json:"query_parameter_limit"` InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` TablePrefix string `json:"table_prefix"` + // When a query uses a table with RLS enforced fields, it will be required to + // parametrized those fields. Associate tables are not covered! + RLSEnforcedFields []string `json:"rls_enforced_fields"` } diff --git a/internal/gen.go b/internal/gen.go index 08a27b3..5bd1683 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -388,6 +388,20 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum } func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([]Query, error) { + rlsFieldsByTable := make(map[string][]string) // TODO + if len(conf.RLSEnforcedFields) > 0 { + for i := range structs { + tableName := structs[i].Table.Name + for _, f := range structs[i].Fields { + for _, enforced := range conf.RLSEnforcedFields { + if f.Name == enforced { + rlsFieldsByTable[tableName] = append(rlsFieldsByTable[tableName], f.Name) + } + } + } + } + } + qs := make([]Query, 0, len(req.Queries)) for _, query := range req.Queries { if query.Name == "" { @@ -419,9 +433,20 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ if qpl < 0 { return nil, errors.New("invalid query parameter limit") } + enforcedFields := make(map[string]bool) + for _, c := range query.Columns { + if fields, ok := rlsFieldsByTable[c.GetTable().GetName()]; ok { + for _, f := range fields { + enforcedFields[f] = false + } + } + } if len(query.Params) > qpl || qpl == 0 { var cols []pyColumn for _, p := range query.Params { + if _, ok := enforcedFields[p.GetColumn().GetName()]; ok { + enforcedFields[p.Column.Name] = true + } cols = append(cols, pyColumn{ id: p.Number, Column: p.Column, @@ -435,6 +460,9 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ } else { args := make([]QueryValue, 0, len(query.Params)) for _, p := range query.Params { + if _, ok := enforcedFields[p.GetColumn().GetName()]; ok { + enforcedFields[p.Column.Name] = true + } args = append(args, QueryValue{ Name: paramName(p), Typ: makePyType(req, p.Column), @@ -442,7 +470,11 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ } gq.Args = args } - + for field, is_enforced := range enforcedFields { + if !is_enforced { + return nil, fmt.Errorf("RLS field %s is not filtered in query %s", field, query.Name) + } + } if len(query.Columns) == 1 { c := query.Columns[0] gq.Ret = QueryValue{ From 7b4b3550150ceb3ef16f5866cec9a81fed392fac Mon Sep 17 00:00:00 2001 From: simo7 Date: Sun, 6 Oct 2024 01:58:29 +0200 Subject: [PATCH 6/8] Implement sqlc.embed --- internal/config.go | 2 + internal/gen.go | 116 ++++++++++++++++++++++++++++++++++++++------ internal/imports.go | 4 +- 3 files changed, 104 insertions(+), 18 deletions(-) diff --git a/internal/config.go b/internal/config.go index 032eb01..706c978 100644 --- a/internal/config.go +++ b/internal/config.go @@ -12,3 +12,5 @@ type Config struct { // parametrized those fields. Associate tables are not covered! RLSEnforcedFields []string `json:"rls_enforced_fields"` } + +const MODELS_FILENAME = "db_models" diff --git a/internal/gen.go b/internal/gen.go index 5bd1683..ca82758 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -56,6 +56,8 @@ type Field struct { Name string Type pyType Comment string + // EmbedFields contains the embedded fields that require scanning. + EmbedFields []Field } type Struct struct { @@ -81,7 +83,7 @@ func (v QueryValue) Annotation() *pyast.Node { if v.Emit { return poet.Name(v.Struct.Name) } else { - return typeRefNode("models", v.Struct.Name) + return typeRefNode(MODELS_FILENAME, v.Struct.Name) } } panic("no type for QueryValue: " + v.Name) @@ -109,14 +111,41 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node { call := &pyast.Call{ Func: v.Annotation(), } - for i, f := range v.Struct.Fields { - call.Keywords = append(call.Keywords, &pyast.Keyword{ - Arg: f.Name, - Value: subscriptNode( + var idx int + for _, f := range v.Struct.Fields { + var val *pyast.Node + if len(f.EmbedFields) > 0 { + var embedFields []*pyast.Keyword + for _, embed := range f.EmbedFields { + embedFields = append(embedFields, &pyast.Keyword{ + Arg: embed.Name, + Value: subscriptNode( + rowVar, + constantInt(idx), + ), + }) + idx++ + } + val = &pyast.Node{ + Node: &pyast.Node_Call{ + Call: &pyast.Call{ + Func: f.Type.Annotation(false), + Keywords: embedFields, + }, + }, + } + } else { + val = subscriptNode( rowVar, - constantInt(i), - ), + constantInt(idx), + ) + idx++ + } + call.Keywords = append(call.Keywords, &pyast.Keyword{ + Arg: f.Name, + Value: val, }) + } return &pyast.Node{ Node: &pyast.Node_Call{ @@ -355,6 +384,46 @@ func paramName(p *plugin.Parameter) string { type pyColumn struct { id int32 *plugin.Column + embed *pyEmbed +} + +type pyEmbed struct { + modelType string + modelName string + fields []Field +} + +// look through all the structs and attempt to find a matching one to embed +// We need the name of the struct and its field names. +func newPyEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string) *pyEmbed { + if embed == nil { + return nil + } + + for _, s := range structs { + embedSchema := defaultSchema + if embed.Schema != "" { + embedSchema = embed.Schema + } + + // compare the other attributes + if embed.Catalog != s.Table.Catalog || embed.Name != s.Table.Name || embedSchema != s.Table.Schema { + continue + } + + fields := make([]Field, len(s.Fields)) + for i, f := range s.Fields { + fields[i] = f + } + + return &pyEmbed{ + modelType: s.Name, + modelName: s.Name, + fields: fields, + } + } + + return nil } func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColumn) *Struct { @@ -366,6 +435,12 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum for i, c := range columns { colName := columnName(c.Column, i) fieldName := colName + + // override col with expected model name + if c.embed != nil { + colName = c.embed.modelName + } + // Track suffixes by the ID of the column, so that columns referring to // the same numbered parameter can be reused. var suffix int32 @@ -378,17 +453,25 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum if suffix > 0 { fieldName = fmt.Sprintf("%s_%d", fieldName, suffix) } - gs.Fields = append(gs.Fields, Field{ - Name: fieldName, - Type: makePyType(req, c.Column), - }) + f := Field{Name: fieldName} + if c.embed == nil { + f.Type = makePyType(req, c.Column) + } else { + f.Type = pyType{ + InnerType: MODELS_FILENAME + "." + c.embed.modelType, + IsArray: c.IsArray, + IsNull: false, + } + f.EmbedFields = c.embed.fields + } + gs.Fields = append(gs.Fields, f) seen[colName]++ } return &gs } func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([]Query, error) { - rlsFieldsByTable := make(map[string][]string) // TODO + rlsFieldsByTable := make(map[string][]string) if len(conf.RLSEnforcedFields) > 0 { for i := range structs { tableName := structs[i].Table.Name @@ -475,7 +558,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ return nil, fmt.Errorf("RLS field %s is not filtered in query %s", field, query.Name) } } - if len(query.Columns) == 1 { + if len(query.Columns) == 1 && query.Columns[0].EmbedTable == nil { c := query.Columns[0] gq.Ret = QueryValue{ Name: columnName(c, 0), @@ -515,6 +598,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ columns = append(columns, pyColumn{ id: int32(i), Column: c, + embed: newPyEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema), }) } gs = columnsToStruct(req, query.Name+"Row", columns) @@ -893,7 +977,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { ImportFrom: &pyast.ImportFrom{ Module: ctx.C.Package, Names: []*pyast.Node{ - poet.Alias("models"), + poet.Alias(MODELS_FILENAME), }, }, }, @@ -1068,8 +1152,8 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR output := map[string]string{} result := pyprint.Print(buildModelsTree(&tctx, i), pyprint.Options{}) - tctx.SourceName = "db_models.py" - output["db_models.py"] = string(result.Python) + tctx.SourceName = MODELS_FILENAME + ".py" + output[MODELS_FILENAME+".py"] = string(result.Python) files := map[string]struct{}{} for _, q := range queries { diff --git a/internal/imports.go b/internal/imports.go index 08bc710..b892bd8 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -69,7 +69,7 @@ func queryValueUses(name string, qv QueryValue) bool { } func (i *importer) Imports(fileName string) []string { - if fileName == "db_models.py" { + if fileName == MODELS_FILENAME+".py" { return i.modelImports() } return i.queryImports(fileName) @@ -165,7 +165,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map func (i *importer) queryImports(fileName string) []string { std, pkg := i.queryImportSpecs(fileName) - modelImportStr := fmt.Sprintf("from %s import models", i.C.Package) + modelImportStr := fmt.Sprintf("from %s import %s", i.C.Package, MODELS_FILENAME) importLines := []string{ buildImportBlock(std), From b8217016eb5966cea358a592c41bbf3244e2c0b0 Mon Sep 17 00:00:00 2001 From: simo7 Date: Sun, 6 Oct 2024 03:01:23 +0200 Subject: [PATCH 7/8] Add MergeQueryFiles config opt --- internal/config.go | 2 ++ internal/gen.go | 11 +++++++++-- internal/imports.go | 4 ++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/internal/config.go b/internal/config.go index 706c978..ccf5a48 100644 --- a/internal/config.go +++ b/internal/config.go @@ -11,6 +11,8 @@ type Config struct { // When a query uses a table with RLS enforced fields, it will be required to // parametrized those fields. Associate tables are not covered! RLSEnforcedFields []string `json:"rls_enforced_fields"` + // Merge queries defined in different files into one output queries.py file + MergeQueryFiles bool `json:"merge_query_files"` } const MODELS_FILENAME = "db_models" diff --git a/internal/gen.go b/internal/gen.go index ca82758..5877a9c 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -1113,6 +1113,9 @@ type pyTmplCtx struct { } func (t *pyTmplCtx) OutputQuery(sourceName string) bool { + if t.C.MergeQueryFiles { + return true + } return t.SourceName == sourceName } @@ -1156,8 +1159,12 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR output[MODELS_FILENAME+".py"] = string(result.Python) files := map[string]struct{}{} - for _, q := range queries { - files[q.SourceName] = struct{}{} + if i.C.MergeQueryFiles { + files["db_queries.sql"] = struct{}{} + } else { + for _, q := range queries { + files[q.SourceName] = struct{}{} + } } for source := range files { diff --git a/internal/imports.go b/internal/imports.go index b892bd8..f5011b7 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -113,7 +113,7 @@ func (i *importer) modelImports() []string { func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map[string]importSpec) { queryUses := func(name string) bool { for _, q := range i.Queries { - if q.SourceName != fileName { + if !i.C.MergeQueryFiles && q.SourceName != fileName { continue } if queryValueUses(name, q.Ret) { @@ -144,7 +144,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map } for _, q := range i.Queries { - if q.SourceName != fileName { + if !i.C.MergeQueryFiles && q.SourceName != fileName { continue } if q.Cmd == ":one" { From eaae4615626a3832ab41c10e0d5fad3ec7c48514 Mon Sep 17 00:00:00 2001 From: simo7 Date: Sun, 6 Oct 2024 19:47:38 +0200 Subject: [PATCH 8/8] sqlc.embed: skip object if id is None --- internal/config.go | 5 ++++- internal/gen.go | 32 +++++++++++++++++++------------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/internal/config.go b/internal/config.go index ccf5a48..2a9bcdb 100644 --- a/internal/config.go +++ b/internal/config.go @@ -9,7 +9,10 @@ type Config struct { InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` TablePrefix string `json:"table_prefix"` // When a query uses a table with RLS enforced fields, it will be required to - // parametrized those fields. Associate tables are not covered! + // parametrized those fields. Not covered: + // - Associate tables + // - sqlc.embed() + // - json_agg(tbl.*) RLSEnforcedFields []string `json:"rls_enforced_fields"` // Merge queries defined in different files into one output queries.py file MergeQueryFiles bool `json:"merge_query_files"` diff --git a/internal/gen.go b/internal/gen.go index 5877a9c..8454a7d 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -118,27 +118,33 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node { var embedFields []*pyast.Keyword for _, embed := range f.EmbedFields { embedFields = append(embedFields, &pyast.Keyword{ - Arg: embed.Name, - Value: subscriptNode( - rowVar, - constantInt(idx), - ), + Arg: embed.Name, + Value: subscriptNode(rowVar, constantInt(idx)), }) idx++ } val = &pyast.Node{ - Node: &pyast.Node_Call{ - Call: &pyast.Call{ - Func: f.Type.Annotation(false), - Keywords: embedFields, + Node: &pyast.Node_Compare{ + Compare: &pyast.Compare{ + Left: &pyast.Node{ + Node: &pyast.Node_Call{ + Call: &pyast.Call{ + Func: f.Type.Annotation(false), + Keywords: embedFields, + }, + }, + }, + Ops: []*pyast.Node{ + poet.Name(fmt.Sprintf("if row[%d] else", idx-len(f.EmbedFields))), + }, + Comparators: []*pyast.Node{ + poet.Constant(nil), + }, }, }, } } else { - val = subscriptNode( - rowVar, - constantInt(idx), - ) + val = subscriptNode(rowVar, constantInt(idx)) idx++ } call.Keywords = append(call.Keywords, &pyast.Keyword{