8000 Ignore keys of unknown alg when verifying JWTs with JWKS (#4725) · open-policy-agent/opa@cb6a4c0 · GitHub
  • [go: up one dir, main page]

    Skip to content

    Commit cb6a4c0

    Browse files
    authored
    Ignore keys of unknown alg when verifying JWTs with JWKS (#4725)
    Additionally, improve the JWT verification process when a JWKS is provided: * If `kid` is present in JWT header, and exists in JWKS — verify using that key only. * If `kid` not present in JWT header, try verification only using keys matching the `alg` provided in the JWT header (mandatory claim). Fixes #4699 Signed-off-by: Anders Eknert <anders@eknert.com>
    1 parent 1889f24 commit cb6a4c0

    File tree

    6 files changed

    +319
    -29
    lines changed

    6 files changed

    +319
    -29
    lines changed

    internal/jwx/jwa/signature.go

    Lines changed: 3 additions & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -27,6 +27,7 @@ const (
    2727
    RS384 SignatureAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
    2828
    RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
    2929
    NoValue SignatureAlgorithm = "" // No value is different from none
    30+
    Unsupported SignatureAlgorithm = "unsupported"
    3031
    )
    3132

    3233
    // Accept is used when conversion from values given by
    @@ -69,7 +70,8 @@ func (signature *SignatureAlgorithm) UnmarshalJSON(data []byte) error {
    6970
    }
    7071
    _, ok := signatureAlg[quoted]
    7172
    if !ok {
    72-
    return errors.New("unknown signature algorithm")
    73+
    *signature = Unsupported
    74+
    return nil
    7375
    }
    7476
    *signature = SignatureAlgorithm(quoted)
    7577
    return nil

    internal/jwx/jwk/jwk.go

    Lines changed: 3 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -99,6 +99,9 @@ func parse(jwkSrc string) (*Set, error) {
    9999
    } else {
    100100
    for i := range rawKeySetJSON.Keys {
    101101
    rawKeyJSON := rawKeySetJSON.Keys[i]
    102+
    if rawKeyJSON.Algorithm != nil && *rawKeyJSON.Algorithm == jwa.Unsupported {
    103+
    continue
    104+
    }
    102105
    jwkKey, err = rawKeyJSON.GenerateKey()
    103106
    if err != nil {
    104107
    return nil, fmt.Errorf("failed to generate key: %w", err)

    internal/jwx/jws/headers_test.go

    Lines changed: 5 additions & 3 deletions
    Original file line numberDiff line numberDiff line change
    @@ -112,9 +112,11 @@ func TestHeader(t *testing.T) {
    112112
    headers := `{"typ":"JWT",` + "\r\n" + ` "alg":"dummy"}`
    113113
    var standardHeaders jws.StandardHeaders
    114114
    err := json.Unmarshal([]byte(headers), &standardHeaders)
    115-
    if err == nil {
    116-
    t.Fatal("Unmarshal should have failed")
    115+
    if err != nil {
    116+
    t.Fatal(err)
    117+
    }
    118+
    if standardHeaders.Algorithm != jwa.Unsupported {
    119+
    t.Errorf("expected unsupported algorithm")
    117120
    }
    118-
    119121
    })
    120122
    }

    internal/jwx/jws/jws_test.go

    Lines changed: 5 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -82,8 +82,11 @@ func TestAlgError(t *testing.T) {
    8282
    const hdr = `{"typ":"JWT",` + "\r\n" + ` "alg":"unknown"}`
    8383
    var standardHeaders jws.StandardHeaders
    8484
    err := json.Unmarshal([]byte(hdr), &standardHeaders)
    85-
    if err == nil {
    86-
    t.Fatal("header parsing should have failed")
    85+
    if err != nil {
    86+
    t.Fatal(err)
    87+
    }
    88+
    if standardHeaders.Algorithm != jwa.Unsupported {
    89+
    t.Errorf("expected unsupported algorithm")
    8790
    }
    8891
    })
    8992
    }

    topdown/tokens.go

    Lines changed: 97 additions & 20 deletions
    Original file line numberDiff line numberDiff line change
    @@ -22,6 +22,7 @@ import (
    2222
    "strings"
    2323

    2424
    "github.com/open-policy-agent/opa/ast"
    25+
    "github.com/open-policy-agent/opa/internal/jwx/jwa"
    2526
    "github.com/open-policy-agent/opa/internal/jwx/jwk"
    2627
    "github.com/open-policy-agent/opa/internal/jwx/jws"
    2728
    "github.com/open-policy-agent/opa/topdown/builtins"
    @@ -268,9 +269,16 @@ func verifyES(publicKey interface{}, digest []byte, signature []byte) error {
    268269
    return fmt.Errorf("ECDSA signature verification error")
    269270
    }
    270271

    271-
    // getKeyFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
    272+
    type verificationKey struct {
    273+
    alg string
    274+
    kid string
    275+
    key interface{}
    276+
    }
    277+
    278+
    // getKeysFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
    272279
    // A valid PEM block is never valid JSON (and vice versa), hence can try parsing both.
    273-
    func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
    280+
    // When provided a JWKS, each key additionally likely contains a key ID and the key algorithm.
    281+
    func getKeysFromCertOrJWK(certificate string) ([]verificationKey, error) {
    274282
    if block, rest := pem.Decode([]byte(certificate)); block != nil {
    275283
    if len(rest) > 0 {
    276284
    return nil, fmt.Errorf("extra data after a PEM certificate block")
    @@ -281,8 +289,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
    281289
    if err != nil {
    282290
    return nil, fmt.Errorf("failed to parse a PEM certificate: %w", err)
    283291
    }
    284-
    285-
    return []interface{}{cert.PublicKey}, nil
    292+
    return []verificationKey{{key: cert.PublicKey}}, nil
    286293
    }
    287294

    288295
    if block.Type == "PUBLIC KEY" {
    @@ -291,7 +298,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
    291298
    return nil, fmt.Errorf("failed to parse a PEM public key: %w", err)
    292299
    }
    293300

    294-
    return []interface{}{key}, nil
    301+
    return []verificationKey{{key: key}}, nil
    295302
    }
    296303

    297304
    return nil, fmt.Errorf("failed to extract a Key from the PEM certificate")
    @@ -302,18 +309,31 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
    302309
    return nil, fmt.Errorf("failed to parse a JWK key (set): %w", err)
    303310
    }
    304311

    305-
    var keys []interface{}
    312+
    var keys []verificationKey
    306313
    for _, k := range jwks.Keys {
    307314
    key, err := k.Materialize()
    308315
    if err != nil {
    309316
    return nil, err
    310317
    }
    311-
    keys = append(keys, key)
    318+
    keys = append(keys, verificationKey{
    319+
    alg: k.GetAlgorithm().String(),
    320+
    kid: k.GetKeyID(),
    321+
    key: key,
    322+
    })
    312323
    }
    313324

    314325
    return keys, nil
    315326
    }
    316327

    328+
    func getKeyByKid(kid string, keys []verificationKey) *verificationKey {
    329+
    for _, key := range keys {
    330+
    if key.kid == kid {
    331+
    return &key
    332+
    }
    333+
    }
    334+
    return nil
    335+
    }
    336+
    317337
    // Implements JWT signature verification.
    318338
    func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify func(publicKey interface{}, digest []byte, signature []byte) error) (ast.Value, error) {
    319339
    token, err := decodeJWT(a)
    @@ -326,7 +346,7 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify
    326346
    return nil, err
    327347
    }
    328348

    329-
    keys, err := getKeyFromCertOrJWK(string(s))
    349+
    keys, err := getKeysFromCertOrJWK(string(s))
    330350
    if err != nil {
    331351
    return nil, err
    332352
    }
    @@ -336,14 +356,45 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify
    336356
    return nil, err
    337357
    }
    338358

    359+
    err = token.decodeHeader()
    360+
    if err != nil {
    361+
    return nil, err
    362+
    }
    363+
    header, err := parseTokenHeader(token)
    364+
    if err != nil {
    365+
    return nil, err
    366+
    }
    367+
    339368
    // Validate the JWT signature
    340-
    for _, key := range keys {
    341-
    err = verify(key,
    342-
    getInputSHA([]byte(token.header+"."+token.payload), hasher),
    343-
    []byte(signature))
    344369

    345-
    if err == nil {
    346-
    return ast.Boolean(true), nil
    370+
    // First, check if there's a matching key ID (`kid`) in both token header and key(s).
    371+
    // If a match is found, verify using only that key. Only applicable when a JWKS was provided.
    372+
    if header.kid != "" {
    373+
    if key := getKeyByKid(header.kid, keys); key != nil {
    374+
    err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
    375+
    376+
    return ast.Boolean(err == nil), nil
    377+
    }
    378+
    }
    379+
    380+
    // If no key ID matched, try to verify using any key in the set
    381+
    // If an alg is present in both the JWT header and the key, skip verification unless they match
    382+
    for _, key := range keys {
    383+
    if key.alg == "" {
    384+
    // No algorithm provided for the key - this is likely a certificate and not a JWKS, so
    385+
    // we'll need to verify to find out
    386+
    err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
    387+
    if err == nil {
    388+
    return ast.Boolean(true), nil
    389+
    }
    390+
    } else {
    391+
    if header.alg != key.alg {
    392+
    continue
    393+
    }
    394+
    err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
    395+
    if err == nil {
    396+
    return ast.Boolean(true), nil
    397+
    }
    347398
    }
    348399
    }
    349400

    @@ -445,7 +496,7 @@ func builtinJWTVerifyHS512(bctx BuiltinContext, args []*ast.Term, iter func(*ast
    445496
    // tokenConstraints holds decoded JWT verification constraints.
    446497
    type tokenConstraints struct {
    447498
    // The set of asymmetric keys we can verify with.
    448-
    keys []interface{}
    499+
    keys []verificationKey
    449500

    450501
    // The single symmetric key we will verify with.
    451502
    secret string
    @@ -495,10 +546,11 @@ func tokenConstraintCert(value ast.Value, constraints *tokenConstraints) error {
    495546
    return fmt.Errorf("cert constraint: must be a string")
    496547
    }
    497548

    498-
    keys, err := getKeyFromCertOrJWK(string(s))
    549+
    keys, err := getKeysFromCertOrJWK(string(s))
    499550
    if err != nil {
    500551
    return err
    501552
    }
    553+
    502554
    constraints.keys = keys
    503555
    return nil
    504556
    }
    @@ -595,14 +647,36 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
    595647
    }
    596648
    // If we're configured with asymmetric key(s) then only trust that
    597649
    if constraints.keys != nil {
    650+
    if kid != "" {
    651+
    if key := getKeyByKid(kid, constraints.keys); key != nil {
    652+
    err := a.verify(key.key, a.hash, plaintext, []byte(signature))
    653+
    if err != nil {
    654+
    return errSignatureNotVerified
    655+
    }
    656+
    return nil
    657+
    }
    658+
    }
    659+
    598660
    verified := false
    599661
    for _, key := range constraints.keys {
    600-
    err := a.verify(key, a.hash, plaintext, []byte(signature))
    601-
    if err == nil {
    602-
    verified = true
    603-
    break
    662+
    if key.alg == "" {
    663+
    err := a.verify(key.key, a.hash, plaintext, []byte(signature))
    664+
    if err == nil {
    665+
    verified = true
    666+
    break
    667+
    }
    668+
    } else {
    669+
    if alg != key.alg {
    670+
    continue
    671+
    }
    672+
    err := a.verify(key.key, a.hash, plaintext, []byte(signature))
    673+
    if err == nil {
    674+
    verified = true
    675+
    break
    676+
    }
    604677
    }
    605678
    }
    679+
    606680
    if !verified {
    607681
    return errSignatureNotVerified
    608682
    }
    @@ -843,6 +917,9 @@ func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, j
    843917
    return err
    844918
    }
    845919
    alg := standardHeaders.GetAlgorithm()
    920+
    if alg == jwa.Unsupported {
    921+
    return fmt.Errorf("unknown signature algorithm")
    922+
    }
    846923

    847924
    if (standardHeaders.Type == "" || standardHeaders.Type == headerJwt) && !json.Valid([]byte(jwsPayload)) {
    848925
    return fmt.Errorf("type is JWT but payload is not JSON")

    0 commit comments

    Comments
     (0)
    0