@@ -22,6 +22,7 @@ import (
22
22
"strings"
23
23
24
24
"github.com/open-policy-agent/opa/ast"
25
+ "github.com/open-policy-agent/opa/internal/jwx/jwa"
25
26
"github.com/open-policy-agent/opa/internal/jwx/jwk"
26
27
"github.com/open-policy-agent/opa/internal/jwx/jws"
27
28
"github.com/open-policy-agent/opa/topdown/builtins"
@@ -268,9 +269,16 @@ func verifyES(publicKey interface{}, digest []byte, signature []byte) error {
268
269
return fmt .Errorf ("ECDSA signature verification error" )
269
270
}
270
271
271
- // getKeyFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
272
+ type verificationKey struct {
273
+ alg string
274
+ kid string
275
+ key interface {}
276
+ }
277
+
278
+ // getKeysFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
272
279
// A valid PEM block is never valid JSON (and vice versa), hence can try parsing both.
273
- func getKeyFromCertOrJWK (certificate string ) ([]interface {}, error ) {
280
+ // When provided a JWKS, each key additionally likely contains a key ID and the key algorithm.
281
+ func getKeysFromCertOrJWK (certificate string ) ([]verificationKey , error ) {
274
282
if block , rest := pem .Decode ([]byte (certificate )); block != nil {
275
283
if len (rest ) > 0 {
276
284
return nil , fmt .Errorf ("extra data after a PEM certificate block" )
@@ -281,8 +289,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
281
289
if err != nil {
282
290
return nil , fmt .Errorf ("failed to parse a PEM certificate: %w" , err )
283
291
}
284
-
285
- return []interface {}{cert .PublicKey }, nil
292
+ return []verificationKey {{key : cert .PublicKey }}, nil
286
293
}
287
294
288
295
if block .Type == "PUBLIC KEY" {
@@ -291,7 +298,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
291
298
return nil , fmt .Errorf ("failed to parse a PEM public key: %w" , err )
292
299
}
293
300
294
- return []interface {} {key }, nil
301
+ return []verificationKey { {key : key } }, nil
295
302
}
296
303
297
304
return nil , fmt .Errorf ("failed to extract a Key from the PEM certificate" )
@@ -302,18 +309,31 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
302
309
return nil , fmt .Errorf ("failed to parse a JWK key (set): %w" , err )
303
310
}
304
311
305
- var keys []interface {}
312
+ var keys []verificationKey
306
313
for _ , k := range jwks .Keys {
307
314
key , err := k .Materialize ()
308
315
if err != nil {
309
316
return nil , err
310
317
}
311
- keys = append (keys , key )
318
+ keys = append (keys , verificationKey {
319
+ alg : k .GetAlgorithm ().String (),
320
+ kid : k .GetKeyID (),
321
+ key : key ,
322
+ })
312
323
}
313
324
314
325
return keys , nil
315
326
}
316
327
328
+ func getKeyByKid (kid string , keys []verificationKey ) * verificationKey {
329
+ for _ , key := range keys {
330
+ if key .kid == kid {
331
+ return & key
332
+ }
333
+ }
334
+ return nil
335
+ }
336
+
317
337
// Implements JWT signature verification.
318
338
func builtinJWTVerify (a ast.Value , b ast.Value , hasher func () hash.Hash , verify func (publicKey interface {}, digest []byte , signature []byte ) error ) (ast.Value , error ) {
319
339
token , err := decodeJWT (a )
@@ -326,7 +346,7 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify
326
346
return nil , err
327
347
}
328
348
329
- keys , err := getKeyFromCertOrJWK (string (s ))
349
+ keys , err := getKeysFromCertOrJWK (string (s ))
330
350
if err != nil {
331
351
return nil , err
332
352
}
@@ -336,14 +356,45 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify
336
356
return nil , err
337
357
}
338
358
359
+ err = token .decodeHeader ()
360
+ if err != nil {
361
+ return nil , err
362
+ }
363
+ header , err := parseTokenHeader (token )
364
+ if err != nil {
365
+ return nil , err
366
+ }
367
+
339
368
// Validate the JWT signature
340
- for _ , key := range keys {
341
- err = verify (key ,
342
- getInputSHA ([]byte (token .header + "." + token .payload ), hasher ),
343
- []byte (signature ))
344
369
345
- if err == nil {
346
- return ast .Boolean (true ), nil
370
+ // First, check if there's a matching key ID (`kid`) in both token header and key(s).
371
+ // If a match is found, verify using only that key. Only applicable when a JWKS was provided.
372
+ if header .kid != "" {
373
+ if key := getKeyByKid (header .kid , keys ); key != nil {
374
+ err = verify (key .key , getInputSHA ([]byte (token .header + "." + token .payload ), hasher ), []byte (signature ))
375
+
376
+ return ast .Boolean (err == nil ), nil
377
+ }
378
+ }
379
+
380
+ // If no key ID matched, try to verify using any key in the set
381
+ // If an alg is present in both the JWT header and the key, skip verification unless they match
382
+ for _ , key := range keys {
383
+ if key .alg == "" {
384
+ // No algorithm provided for the key - this is likely a certificate and not a JWKS, so
385
+ // we'll need to verify to find out
386
+ err = verify (key .key , getInputSHA ([]byte (token .header + "." + token .payload ), hasher ), []byte (signature ))
387
+ if err == nil {
388
+ return ast .Boolean (true ), nil
389
+ }
390
+ } else {
391
+ if header .alg != key .alg {
392
+ continue
393
+ }
394
+ err = verify (key .key , getInputSHA ([]byte (token .header + "." + token .payload ), hasher ), []byte (signature ))
395
+ if err == nil {
396
+ return ast .Boolean (true ), nil
397
+ }
347
398
}
348
399
}
349
400
@@ -445,7 +496,7 @@ func builtinJWTVerifyHS512(bctx BuiltinContext, args []*ast.Term, iter func(*ast
445
496
// tokenConstraints holds decoded JWT verification constraints.
446
497
type tokenConstraints struct {
447
498
// The set of asymmetric keys we can verify with.
448
- keys []interface {}
499
+ keys []verificationKey
449
500
450
501
// The single symmetric key we will verify with.
451
502
secret string
@@ -495,10 +546,11 @@ func tokenConstraintCert(value ast.Value, constraints *tokenConstraints) error {
495
546
return fmt .Errorf ("cert constraint: must be a string" )
496
547
}
497
548
498
- keys , err := getKeyFromCertOrJWK (string (s ))
549
+ keys , err := getKeysFromCertOrJWK (string (s ))
499
550
if err != nil {
500
551
return err
501
552
}
553
+
502
554
constraints .keys = keys
503
555
return nil
504
556
}
@@ -595,14 +647,36 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
595
647
}
596
648
// If we're configured with asymmetric key(s) then only trust that
597
649
if constraints .keys != nil {
650
+ if kid != "" {
651
+ if key := getKeyByKid (kid , constraints .keys ); key != nil {
652
+ err := a .verify (key .key , a .hash , plaintext , []byte (signature ))
653
+ if err != nil {
654
+ return errSignatureNotVerified
655
+ }
656
+ return nil
657
+ }
658
+ }
659
+
598
660
verified := false
599
661
for _ , key := range constraints .keys {
600
- err := a .verify (key , a .hash , plaintext , []byte (signature ))
601
- if err == nil {
602
- verified = true
603
- break
662
+ if key .alg == "" {
663
+ err := a .verify (key .key , a .hash , plaintext , []byte (signature ))
664
+ if err == nil {
665
+ verified = true
666
+ break
667
+ }
668
+ } else {
669
+ if alg != key .alg {
670
+ continue
671
+ }
672
+ err := a .verify (key .key , a .hash , plaintext , []byte (signature ))
673
+ if err == nil {
674
+ verified = true
675
+ break
676
+ }
604
677
}
605
678
}
679
+
606
680
if ! verified {
607
681
return errSignatureNotVerified
608
682
}
@@ -843,6 +917,9 @@ func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, j
843
917
return err
844
918
}
845
919
alg := standardHeaders .GetAlgorithm ()
920
+ if alg == jwa .Unsupported {
921
+ return fmt .Errorf ("unknown signature algorithm" )
922
+ }
846
923
847
924
if (standardHeaders .Type == "" || standardHeaders .Type == headerJwt ) && ! json .Valid ([]byte (jwsPayload )) {
848
925
return fmt .Errorf ("type is JWT but payload is not JSON" )
0 commit comments