@@ -388,6 +388,20 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum
388
388
}
389
389
390
390
func buildQueries (conf Config , req * plugin.GenerateRequest , structs []Struct ) ([]Query , error ) {
391
+ rlsFieldsByTable := make (map [string ][]string ) // TODO
392
+ if len (conf .RLSEnforcedFields ) > 0 {
393
+ for i := range structs {
394
+ tableName := structs [i ].Table .Name
395
+ for _ , f := range structs [i ].Fields {
396
+ for _ , enforced := range conf .RLSEnforcedFields {
397
+ if f .Name == enforced {
398
+ rlsFieldsByTable [tableName ] = append (rlsFieldsByTable [tableName ], f .Name )
399
+ }
400
+ }
401
+ }
402
+ }
403
+ }
404
+
391
405
qs := make ([]Query , 0 , len (req .Queries ))
392
406
for _ , query := range req .Queries {
393
407
if query .Name == "" {
@@ -419,9 +433,20 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([
419
433
if qpl < 0 {
420
434
return nil , errors .New ("invalid query parameter limit" )
421
435
}
436
+ enforcedFields := make (map [string ]bool )
437
+ for _ , c := range query .Columns {
438
+ if fields , ok := rlsFieldsByTable [c .GetTable ().GetName ()]; ok {
439
+ for _ , f := range fields {
440
+ enforcedFields [f ] = false
441
+ }
442
+ }
443
+ }
422
444
if len (query .Params ) > qpl || qpl == 0 {
423
445
var cols []pyColumn
424
446
for _ , p := range query .Params {
447
+ if _ , ok := enforcedFields [p .GetColumn ().GetName ()]; ok {
448
+ enforcedFields [p .Column .Name ] = true
449
+ }
425
450
cols = append (cols , pyColumn {
426
451
id : p .Number ,
427
452
Column : p .Column ,
@@ -435,14 +460,21 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([
435
460
} else {
436
461
args := make ([]QueryValue , 0 , len (query .Params ))
437
462
for _ , p := range query .Params {
463
+ if _ , ok := enforcedFields [p .GetColumn ().GetName ()]; ok {
464
+ enforcedFields [p .Column .Name ] = true
465
+ }
438
466
args = append (args , QueryValue {
439
467
Name : paramName (p ),
440
468
Typ : makePyType (req , p .Column ),
441
469
})
442
470
}
443
471
gq .Args = args
444
472
}
445
-
473
+ for field , is_enforced := range enforcedFields {
474
+ if ! is_enforced {
475
+ return nil , fmt .Errorf ("RLS field %s is not filtered in query %s" , field , query .Name )
476
+ }
477
+ }
446
478
if len (query .Columns ) == 1 {
447
479
c := query .Columns [0 ]
448
480
gq .Ret = QueryValue {
0 commit comments