|
1 | 1 | package oidctest
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "context" |
4 | 5 | "database/sql"
|
5 | 6 | "encoding/json"
|
6 | 7 | "net/http"
|
| 8 | + "net/url" |
7 | 9 | "testing"
|
8 | 10 | "time"
|
9 | 11 |
|
10 | 12 | "github.com/golang-jwt/jwt/v4"
|
11 | 13 | "github.com/stretchr/testify/require"
|
| 14 | + "golang.org/x/xerrors" |
12 | 15 |
|
13 | 16 | "github.com/coder/coder/v2/coderd/database"
|
14 | 17 | "github.com/coder/coder/v2/coderd/database/dbauthz"
|
@@ -114,3 +117,51 @@ func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *coders
|
114 | 117 | _, err := user.User(testutil.Context(t, testutil.WaitShort), "me")
|
115 | 118 | require.NoError(t, err, "user must be able to be fetched")
|
116 | 119 | }
|
| 120 | + |
| 121 | +// OAuth2GetCode emulates a user clicking "allow" on the IDP page. When doing |
| 122 | +// unit tests, it's easier to skip this step sometimes. It does make an actual |
| 123 | +// request to the IDP, so it should be equivalent to doing this "manually" with |
| 124 | +// actual requests. |
| 125 | +// |
| 126 | +// TODO: Is state param optional? Can we grab it from the authURL? |
| 127 | +func OAuth2GetCode(authURL string, state string, doRequest func(req *http.Request) (*http.Response, error)) (string, error) { |
| 128 | + // We need to store some claims, because this is also an OIDC provider, and |
| 129 | + // it expects some claims to be present. |
| 130 | + // TODO: POST or GET method? |
| 131 | + r, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authURL, nil) |
| 132 | + if err != nil { |
| 133 | + return "", xerrors.Errorf("failed to create auth request: %w", err) |
| 134 | + } |
| 135 | + |
| 136 | + expCode := http.StatusTemporaryRedirect |
| 137 | + resp, err := doRequest(r) |
| 138 | + if err != nil { |
| 139 | + return "", xerrors.Errorf("request: %w", err) |
| 140 | + } |
| 141 | + defer resp.Body.Close() |
| 142 | + |
| 143 | + if resp.StatusCode != expCode { |
| 144 | + return "", codersdk.ReadBodyAsError(resp) |
| 145 | + } |
| 146 | + |
| 147 | + to := resp.Header.Get("Location") |
| 148 | + if to == "" { |
| 149 | + return "", xerrors.Errorf("expected redirect location") |
| 150 | + } |
| 151 | + |
| 152 | + toURL, err := url.Parse(to) |
| 153 | + if err != nil { |
| 154 | + return "", xerrors.Errorf("failed to parse redirect location: %w", err) |
| 155 | + } |
| 156 | + |
| 157 | + code := toURL.Query().Get("code") |
| 158 | + if code == "" { |
| 159 | + return "", xerrors.Errorf("expected code in redirect location") |
| 160 | + } |
| 161 | + |
| 162 | + newState := toURL.Query().Get("state") |
| 163 | + if newState != state { |
| 164 | + return "", xerrors.Errorf("expected state %q, got %q", state, newState) |
| 165 | + } |
| 166 | + return code, nil |
| 167 | +} |
0 commit comments