@@ -1008,7 +1008,7 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
1008
1008
std ["pydantic_base_class" ] = i .importPydanticBaseClass ()
1009
1009
std ["pydantic.Field" ] = importSpec {Module : "pydantic" , Name : "Field" }
1010
1010
std [ENUMS_FILENAME ] = importSpec {Module : "." , Name : ENUMS_FILENAME }
1011
- std [TYPED_DICTS_FILENAME ] = importSpec {Module : "." , Name : TYPED_DICTS_FILENAME + " as dct" }
1011
+ std [MODEL_DICTS_FILENAME ] = importSpec {Module : "." , Name : MODEL_DICTS_FILENAME + " as dct" }
1012
1012
mod .Body = append (mod .Body , buildImportGroup (std ), buildImportGroup (pkg ))
1013
1013
1014
1014
for _ , m := range ctx .Models {
@@ -1040,7 +1040,7 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
1040
1040
return & pyast.Node {Node : & pyast.Node_Module {Module : mod }}
1041
1041
}
1042
1042
1043
- func buildTypedDictsTree (ctx * pyTmplCtx , i * importer ) * pyast.Node {
1043
+ func buildModelTypedDictsTree (ctx * pyTmplCtx , i * importer ) * pyast.Node {
1044
1044
mod := moduleNode (ctx .SqlcVersion , "" )
1045
1045
std , pkg := i .modelImportSpecs ()
1046
1046
std ["pydantic.Field" ] = importSpec {Module : "pydantic" , Name : "Field" }
@@ -1063,6 +1063,18 @@ func buildTypedDictsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
1063
1063
})
1064
1064
}
1065
1065
1066
+ return & pyast.Node {Node : & pyast.Node_Module {Module : mod }}
1067
+ }
1068
+
1069
+ func buildQueryTypedDictsTree (ctx * pyTmplCtx , i * importer ) * pyast.Node {
1070
+ mod := moduleNode (ctx .SqlcVersion , "" )
1071
+ std , pkg := i .modelImportSpecs ()
1072
+ std ["pydantic.Field" ] = importSpec {Module : "pydantic" , Name : "Field" }
1073
+ std ["typing.TypedDict" ] = importSpec {Module : "typing" , Name : "TypedDict" }
1074
+ std [ENUMS_FILENAME ] = importSpec {Module : "." , Name : ENUMS_FILENAME }
1075
+ std [MODELS_FILENAME ] = importSpec {Module : "." , Name : MODELS_FILENAME }
1076
+ mod .Body = append (mod .Body , buildImportGroup (std ), buildImportGroup (pkg ))
1077
+
1066
1078
for _ , q := range ctx .Queries {
1067
1079
if ! ctx .OutputQuery (q .SourceName ) {
1068
1080
continue
@@ -1071,9 +1083,6 @@ func buildTypedDictsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
1071
1083
if arg .EmitStruct () {
1072
1084
def := typedDictNode (arg .Struct .Name )
1073
1085
for _ , f := range arg .Struct .Fields {
1074
- if strings .HasPrefix (f .Type .InnerType , MODELS_FILENAME + "." ) {
1075
- f .Type .InnerType = strings .TrimPrefix (f .Type .InnerType , MODELS_FILENAME + "." )
1076
- }
1077
1086
def .Body = append (def .Body , fieldNode (f ))
1078
1087
}
1079
1088
mod .Body = append (mod .Body , poet .Node (def ))
@@ -1082,9 +1091,6 @@ func buildTypedDictsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
1082
1091
if q .Ret .EmitStruct () {
1083
1092
def := typedDictNode (q .Ret .Struct .Name )
1084
1093
for _ , f := range q .Ret .Struct .Fields {
1085
- if strings .HasPrefix (f .Type .InnerType , MODELS_FILENAME + "." ) {
1086
- f .Type .InnerType = strings .TrimPrefix (f .Type .InnerType , MODELS_FILENAME + "." )
1087
- }
1088
1094
def .Body = append (def .Body , fieldNode (f ))
1089
1095
}
1090
1096
mod .Body = append (mod .Body , poet .Node (def ))
@@ -1174,7 +1180,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
1174
1180
mod := moduleNode (ctx .SqlcVersion , source )
1175
1181
std , pkg := i .queryImportSpecs (source )
1176
1182
std [ENUMS_FILENAME ] = importSpec {Module : "." , Name : ENUMS_FILENAME }
1177
- std [TYPED_DICTS_FILENAME ] = importSpec {Module : "." , Name : TYPED_DICTS_FILENAME + " as dct" }
1183
+ std [QUERY_DICTS_FILENAME ] = importSpec {Module : "." , Name : QUERY_DICTS_FILENAME + " as dct" }
1178
1184
mod .Body = append (mod .Body , buildImportGroup (std ), buildImportGroup (pkg ))
1179
1185
mod .Body = append (mod .Body , & pyast.Node {
1180
1186
Node : & pyast.Node_ImportGroup {
@@ -1376,9 +1382,13 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR
1376
1382
tctx .SourceName = MODELS_FILENAME + ".py"
1377
1383
output [MODELS_FILENAME + ".py" ] = string (result .Python )
1378
1384
1379
- result = pyprint .Print (buildTypedDictsTree (& tctx , i ), pyprint.Options {})
1380
- tctx .SourceName = TYPED_DICTS_FILENAME + ".py"
1381
- output [TYPED_DICTS_FILENAME + ".py" ] = string (result .Python )
1385
+ result = pyprint .Print (buildModelTypedDictsTree (& tctx , i ), pyprint.Options {})
1386
+ tctx .SourceName = MODEL_DICTS_FILENAME + ".py"
1387
+ output [MODEL_DICTS_FILENAME + ".py" ] = string (result .Python )
1388
+
1389
+ result = pyprint .Print (buildQueryTypedDictsTree (& tctx , i ), pyprint.Options {})
1390
+ tctx .SourceName = QUERY_DICTS_FILENAME + ".py"
1391
+ output [QUERY_DICTS_FILENAME + ".py" ] = string (result .Python )
1382
1392
1383
1393
files := map [string ]struct {}{}
1384
1394
if i .C .MergeQueryFiles {
0 commit comments