8000 RLS enforced fields · sqlc-dev/sqlc-gen-python@b376daa · GitHub
[go: up one dir, main page]

Skip to content

Commit b376daa

Browse files
committed
RLS enforced fields
1 parent abec3c8 commit b376daa

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

internal/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@ type Config struct {
88
QueryParameterLimit *int32 `json:"query_parameter_limit"`
99
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
1010
TablePrefix string `json:"table_prefix"`
11+
// When a query uses a table with RLS enforced fields, it will be required to
12+
// parametrized those fields. Associate tables are not covered!
13+
RLSEnforcedFields []string `json:"rls_enforced_fields"`
1114
}

internal/gen.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,20 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum
388388
}
389389

390390
func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([]Query, error) {
391+
rlsFieldsByTable := make(map[string][]string) // TODO
392+
if len(conf.RLSEnforcedFields) > 0 {
393+
for i := range structs {
394+
tableName := structs[i].Table.Name
395+
for _, f := range structs[i].Fields {
396+
for _, enforced := range conf.RLSEnforcedFields {
397+
if f.Name == enforced {
398+
rlsFieldsByTable[tableName] = append(rlsFieldsByTable[tableName], f.Name)
399+
}
400+
}
401+
}
402+
}
403+
}
404+
391405
qs := make([]Query, 0, len(req.Queries))
392406
for _, query := range req.Queries {
393407
if query.Name == "" {
@@ -419,9 +433,20 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([
419433
if qpl < 0 {
420434
return nil, errors.New("invalid query parameter limit")
421435
}
436+
enforcedFields := make(map[string]bool)
437+
for _, c := range query.Columns {
438+
if fields, ok := rlsFieldsByTable[c.GetTable().GetName()]; ok {
439+
for _, f := range fields {
440+
enforcedFields[f] = false
441+
}
442+
}
443+
}
422444
if len(query.Params) > qpl || qpl == 0 {
423445
var cols []pyColumn
424446
for _, p := range query.Params {
447+
if _, ok := enforcedFields[p.GetColumn().GetName()]; ok {
448+
enforcedFields[p.Column.Name] = true
449+
}
425450
cols = append(cols, pyColumn{
426451
id: p.Number,
427452
Column: p.Column,
@@ -435,14 +460,21 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([
435460
} else {
436461
args := make([]QueryValue, 0, len(query.Params))
437462
for _, p := range query.Params {
463+
if _, ok := enforcedFields[p.GetColumn().GetName()]; ok {
464+
enforcedFields[p.Column.Name] = true
465+
}
438466
args = append(args, QueryValue{
439467
Name: paramName(p),
440468
Typ: makePyType(req, p.Column),
441469
})
442470
}
443471
gq.Args = args
444472
}
445-
473+
for field, is_enforced := range enforcedFields {
474+
if !is_enforced {
475+
return nil, fmt.Errorf("RLS field %s is not filtered in query %s", field, query.Name)
476+
}
477+
}
446478
if len(query.Columns) == 1 {
447479
c := query.Columns[0]
448480
gq.Ret = QueryValue{

0 commit comments

Comments
 (0)
0