10000 Annotate TypedDicts fields with original pydantic type · cortea-ai/sqlc-gen-python@1c14ef1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1c14ef1

Browse files
committed
Annotate TypedDicts fields with original pydantic type
1 parent 0661793 commit 1c14ef1

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

internal/config.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@ type Config struct {
2121

2222
const MODELS_FILENAME = "db_models"
2323
const ENUMS_FILENAME = "db_enums"
24-
const TYPED_DICTS_FILENAME = "db_typed_dicts"
24+
const MODEL_DICTS_FILENAME = "db_model_dicts"
25+
const QUERY_DICTS_FILENAME = "db_query_dicts"

internal/gen.go

Lines changed: 22 additions & 12 deletions
1385
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,7 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
10081008
std["pydantic_base_class"] = i.importPydanticBaseClass()
10091009
std["pydantic.Field"] = importSpec{Module: "pydantic", Name: "Field"}
10101010
std[ENUMS_FILENAME] = importSpec{Module: ".", Name: ENUMS_FILENAME}
1011-
std[TYPED_DICTS_FILENAME] = importSpec{Module: ".", Name: TYPED_DICTS_FILENAME + " as dct"}
1011+
std[MODEL_DICTS_FILENAME] = importSpec{Module: ".", Name: MODEL_DICTS_FILENAME + " as dct"}
10121012
mod.Body = append(mod.Body, buildImportGroup(std), buildImportGroup(pkg))
10131013

10141014
for _, m := range ctx.Models {
@@ -1040,7 +1040,7 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
10401040
return &pyast.Node{Node: &pyast.Node_Module{Module: mod}}
10411041
}
10421042

1043-
func buildTypedDictsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
1043+
func buildModelTypedDictsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
10441044
mod := moduleNode(ctx.SqlcVersion, "")
10451045
std, pkg := i.modelImportSpecs()
10461046
std["pydantic.Field"] = importSpec{Module: "pydantic", Name: "Field"}
@@ -1063,6 +1063,18 @@ func buildTypedDictsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
10631063
})
10641064
}
10651065

1066+
return &pyast.Node{Node: &pyast.Node_Module{Module: mod}}
1067+
}
1068+
1069+
func buildQueryTypedDictsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
1070+
mod := moduleNode(ctx.SqlcVersion, "")
1071+
std, pkg := i.modelImportSpecs()
1072+
std["pydantic.Field"] = importSpec{Module: "pydantic", Name: "Field"}
1073+
std["typing.TypedDict"] = importSpec{Module: "typing", Name: "TypedDict"}
1074+
std[ENUMS_FILENAME] = importSpec{Module: ".", Name: ENUMS_FILENAME}
1075+
std[MODELS_FILENAME] = importSpec{Module: ".", Name: MODELS_FILENAME}
1076+
mod.Body = append(mod.Body, buildImportGroup(std), buildImportGroup(pkg))
1077+
10661078
for _, q := range ctx.Queries {
10671079
if !ctx.OutputQuery(q.SourceName) {
10681080
continue
@@ -1071,9 +1083,6 @@ func buildTypedDictsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
10711083
if arg.EmitStruct() {
10721084
def := typedDictNode(arg.Struct.Name)
10731085
for _, f := range arg.Struct.Fields {
1074-
if strings.HasPrefix(f.Type.InnerType, MODELS_FILENAME+".") {
1075-
f.Type.InnerType = strings.TrimPrefix(f.Type.InnerType, MODELS_FILENAME+".")
1076-
}
10771086
def.Body = append(def.Body, fieldNode(f))
10781087
}
10791088
mod.Body = append(mod.Body, poet.Node(def))
@@ -1082,9 +1091,6 @@ func buildTypedDictsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
10821091
if q.Ret.EmitStruct() {
10831092
def := typedDictNode(q.Ret.Struct.Name)
10841093
for _, f := range q.Ret.Struct.Fields {
1085-
if strings.HasPrefix(f.Type.InnerType, MODELS_FILENAME+".") {
1086-
f.Type.InnerType = strings.TrimPrefix(f.Type.InnerType, MODELS_FILENAME+".")
1087-
}
10881094
def.Body = append(def.Body, fieldNode(f))
10891095
}
10901096
mod.Body = append(mod.Body, poet.Node(def))
@@ -1174,7 +1180,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
11741180
mod := moduleNode(ctx.SqlcVersion, source)
11751181
std, pkg := i.queryImportSpecs(source)
11761182
std[ENUMS_FILENAME] = importSpec{Module: ".", Name: ENUMS_FILENAME}
1177-
std[TYPED_DICTS_FILENAME] = importSpec{Module: ".", Name: TYPED_DICTS_FILENAME + " as dct"}
1183+
std[QUERY_DICTS_FILENAME] = importSpec{Module: ".", Name: QUERY_DICTS_FILENAME + " as dct"}
11781184
mod.Body = append(mod.Body, buildImportGroup(std), buildImportGroup(pkg))
11791185
mod.Body = append(mod.Body, &pyast.Node{
11801186
Node: &pyast.Node_ImportGroup{
@@ -1376,9 +1382,13 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR
13761382
tctx.SourceName = MODELS_FILENAME + ".py"
13771383
output[MODELS_FILENAME+".py"] = string(result.Python)
13781384

1379-
result = pyprint.Print(buildTypedDictsTree(&tctx, i), pyprint.Options{})
1380-
tctx.SourceName = TYPED_DICTS_FILENAME + ".py"
1381-
output[TYPED_DICTS_FILENAME+".py"] = string(result.Python)
+
result = pyprint.Print(buildModelTypedDictsTree(&tctx, i), pyprint.Options{})
1386+
tctx.SourceName = MODEL_DICTS_FILENAME + ".py"
1387+
output[MODEL_DICTS_FILENAME+".py"] = string(result.Python)
1388+
1389+
result = pyprint.Print(buildQueryTypedDictsTree(&tctx, i), pyprint.Options{})
1390+
tctx.SourceName = QUERY_DICTS_FILENAME + ".py"
1391+
output[QUERY_DICTS_FILENAME+".py"] = string(result.Python)
13821392

13831393
files := map[string]struct{}{}
13841394
if i.C.MergeQueryFiles {

0 commit comments

Comments
 (0)
0