8000 chore: implement device auth flow for fake idp by Emyrk · Pull Request #11707 · coder/coder · GitHub
[go: up one dir, main page]

Skip to content

chore: implement device auth flow for fake idp #11707

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
chore: implement device auth flow for fake idp
  • Loading branch information
Emyrk committed Jan 22, 2024
commit 3bd9d36e9d692b7192e5641483489bba078ed779
206 changes: 180 additions & 26 deletions coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"mime"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -47,6 +49,13 @@ type token struct {
exp time.Time
}

type deviceFlow struct {
// userInput is the expected input to authenticate the device flow.
userInput string
exp time.Time
granted bool
}

// FakeIDP is a functional OIDC provider.
// It only supports 1 OIDC client.
type FakeIDP struct {
Expand Down Expand Up @@ -79,6 +88,9 @@ type FakeIDP struct {
refreshTokens *syncmap.Map[string, string]
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
// Device flow
deviceCode *syncmap.Map[string, deviceFlow]
deviceCodeInput *syncmap.Map[string, externalauth.ExchangeDeviceCodeResponse]

// hooks
// hookValidRedirectURL can be used to reject a redirect url from the
Expand Down Expand Up @@ -229,6 +241,7 @@ const (
keysPath = "/oauth2/keys"
userInfoPath = "/oauth2/userinfo"
deviceAuth = "/login/device/code"
deviceVerify = "/login/device"
)

func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
Expand All @@ -249,6 +262,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
refreshTokensUsed: syncmap.New[string, bool](),
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
deviceCode: syncmap.New[string, deviceFlow](),
hookOnRefresh: func(_ string) error { return nil },
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
hookValidRedirectURL: func(redirectURL string) error { return nil },
Expand Down Expand Up @@ -291,11 +305,12 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
// ProviderJSON is the JSON representation of the OpenID Connect provider
// These are all the urls that the IDP will respond to.
f.provider = ProviderJSON{
Issuer: issuer,
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
Issuer: issuer,
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
DeviceCodeURL: u.ResolveReference(&url.URL{Path: deviceAuth}).String(),
Algorithms: []string{
"RS256",
},
Expand Down Expand Up @@ -539,12 +554,13 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map

// ProviderJSON is the .well-known/configuration JSON
type ProviderJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
DeviceCodeURL string `json:"device_authorization_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
// This is custom
ExternalAuthURL string `json:"external_auth_url"`
}
Expand Down Expand Up @@ -712,8 +728,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
}))

mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
values, err := f.authenticateOIDCClientRequest(t, r)
var values url.Values
var err error
if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" {
values = r.URL.Query()
} else {
values, err = f.authenticateOIDCClientRequest(t, r)
}
f.logger.Info(r.Context(), "http idp call token",
slog.F("url", r.URL.String()),
slog.F("valid", err == nil),
slog.F("grant_type", values.Get("grant_type")),
slog.F("values", values.Encode()),
Expand Down Expand Up @@ -789,6 +812,35 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
f.refreshTokens.Delete(refreshToken)
case "urn:ietf:params:oauth:grant-type:device_code":
// Device flow
var resp externalauth.ExchangeDeviceCodeResponse
deviceCode := values.Get("device_code")
if deviceCode == "" {
resp.Error = "invalid_request"
resp.ErrorDescription = "missing device_code"
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
return
}

deviceFlow, ok := f.deviceCode.Load(deviceCode)
if !ok {
resp.Error = "invalid_request"
resp.ErrorDescription = "device_code provided not found"
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
return
}

if !deviceFlow.granted {
// Status code ok with the error as pending.
resp.Error = "authorization_pending"
resp.ErrorDescription = ""
httpapi.Write(r.Context(), rw, http.StatusOK, resp)
return
}

// Would be nice to get an actual email here.
claims = jwt.MapClaims{
"email": "unknown-dev-auth",
}
default:
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
Expand All @@ -812,8 +864,19 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
// Store the claims for the next refresh
f.refreshIDTokenClaims.Store(refreshToken, claims)

rw.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(rw).Encode(token)
if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
rw.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(rw).Encode(token)
return
}

// Default to form encode. Just to make sure our code sets the right headers.
rw.Header().Set("Content-Type", "application/x-www-form-urlencoded")
vals := url.Values{}
for k, v := range token {
vals.Set(k, fmt.Sprintf("%v", v))
}
_, _ = rw.Write([]byte(vals.Encode()))
}))

validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
Expand Down Expand Up @@ -891,10 +954,68 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
_ = json.NewEncoder(rw).Encode(set)
}))

mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http call device verify")

inputParam := "user_input"
userInput := r.URL.Query().Get(inputParam)
if userInput == "" {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid user input",
Detail: fmt.Sprintf("Hit this url again with ?%s=<user_code>", inputParam),
})
return
}

deviceCode := r.URL.Query().Get("device_code")
if deviceCode == "" {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "Hit this url again with ?device_code=<device_code>",
})
return
}

flow, ok := f.deviceCode.Load(deviceCode)
if !ok {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "Device code not found.",
})
return
}

if time.Now().After(flow.exp) {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "Device code expired.",
})
return
}

if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "user code does not match",
})
return
}

f.deviceCode.Store(deviceCode, deviceFlow{
userInput: flow.userInput,
exp: flow.exp,
granted: true,
})
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
Message: "Device authenticated!",
})
}))

mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http call device auth")

p := httpapi.NewQueryParamParser()
p.Required("client_id")
p.Required("scopes")
clientID := p.String(r.URL.Query(), "", "client_id")
_ = p.String(r.URL.Query(), "", "scopes")
if len(p.Errors) > 0 {
Expand All @@ -912,24 +1033,42 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
return
}

deviceCode := uuid.NewString()
lifetime := time.Second * 900
flow := deviceFlow{
userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8),
}
f.deviceCode.Store(deviceCode, deviceFlow{
userInput: flow.userInput,
exp: time.Now().Add(lifetime),
})

verifyURL := f.issuerURL.ResolveReference(&url.URL{
Path: deviceVerify,
RawQuery: url.Values{
"device_code": {deviceCode},
"user_input": {flow.userInput},
}.Encode(),
}).String()

if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{
"device_code": uuid.NewString(),
"user_code": "1234",
"verification_uri": "",
"expires_in": 900,
"interval": 0,
"device_code": deviceCode,
"user_code": flow.userInput,
"verification_uri": verifyURL,
"expires_in": int(lifetime.Seconds()),
"interval": 3,
})
return
}

// By default, GitHub form encodes these.
_, _ = fmt.Fprint(rw, url.Values{
"device_code": {uuid.NewString()},
"user_code": {"1234"},
"verification_uri": {""},
"expires_in": {"900"},
"interval": {"0"},
"device_code": {deviceCode},
"user_code": {flow.userInput},
"verification_uri": {verifyURL},
"expires_in": {strconv.Itoa(int(lifetime.Seconds()))},
"interval": {"3"},
})
}))

Expand Down Expand Up @@ -1034,6 +1173,8 @@ type ExternalAuthConfigOptions struct {
// co F438 mpletely customize the response. It captures all routes under the /external-auth-validate/*
// so the caller can do whatever they want and even add routes.
routes map[string]func(email string, rw http.ResponseWriter, r *http.Request)

UseDeviceAuth bool
}

func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions {
Expand Down Expand Up @@ -1080,17 +1221,30 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
}
}
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil))
cfg := &externalauth.Config{
DisplayName: id,
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
InstrumentedOAuth2Config: oauthCfg,
ID: id,
// No defaults for these fields by omitting the type
Type: "",
DisplayIcon: f.WellknownConfig().UserInfoURL,
// Omit the /user for the validate so we can easily append to it when modifying
// the cfg for advanced tests.
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
DeviceAuth: &externalauth.DeviceAuth{
Config: oauthCfg,
ClientID: f.clientID,
TokenURL: f.provider.TokenURL,
Scopes: []string{},
CodeURL: f.provider.DeviceCodeURL,
},
}

if !custom.UseDeviceAuth {
cfg.DeviceAuth = nil
}

for _, opt := range opts {
opt(cfg)
}
Expand Down
10 changes: 8 additions & 2 deletions scripts/testidp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var (
expiry = flag.Duration("expiry", time.Minute*5, "Token expiry")
clientID = flag.String("client-id", "static-client-id", "Client ID, set empty to be random")
clientSecret = flag.String("client-sec", "static-client-secret", "Client Secret, set empty to be random")
deviceFlow = flag.Bool("device-flow", false, "Enable device flow")
// By default, no regex means it will never match anything. So at least default to matching something.
extRegex = flag.String("ext-regex", `^(https?://)?example\.com(/.*)?$`, "External auth regex")
)
Expand Down Expand Up @@ -66,14 +67,18 @@ func RunIDP() func(t *testing.T) {
id, sec := idp.AppCredentials()
prov := idp.WellknownConfig()
const appID = "fake"
coderCfg := idp.ExternalAuthConfig(t, appID, nil)
coderCfg := idp.ExternalAuthConfig(t, appID, &oidctest.ExternalAuthConfigOptions{
UseDeviceAuth: *deviceFlow,
})

log.Println("IDP Issuer URL", idp.IssuerURL())
log.Println("Coderd Flags")

deviceCodeURL := ""
if coderCfg.DeviceAuth != nil {
deviceCodeURL = coderCfg.DeviceAuth.CodeURL
}

cfg := withClientSecret{
ClientSecret: sec,
ExternalAuthConfig: codersdk.ExternalAuthConfig{
Expand All @@ -89,13 +94,14 @@ func RunIDP() func(t *testing.T) {
NoRefresh: false,
Scopes: []string{"openid", "email", "profile"},
ExtraTokenKeys: coderCfg.ExtraTokenKeys,
DeviceFlow: coderCfg.DeviceAuth != nil,
DeviceFlow: *deviceFlow,
DeviceCodeURL: deviceCodeURL,
Regex: *extRegex,
DisplayName: coderCfg.DisplayName,
DisplayIcon: coderCfg.DisplayIcon,
},
}

data, err := json.Marshal([]withClientSecret{cfg})
require.NoError(t, err)
log.Printf(`--external-auth-providers='%s'`, string(data))
Expand Down
0