Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 9 additions & 40 deletions auth/actionsoidc/actionsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ import (
"net/url"
"os"
"strings"
"time"

"github.com/golang-jwt/jwt/v5"
)

const (
Expand All @@ -52,77 +49,49 @@ const (
// bearer token are read from the EnvRequestURL and EnvRequestToken environment
// variables, which Actions injects into a job that has the 'id-token: write'
// permission.
//
// It returns the ID token and its expiration time. The expiration is read from
// the token's 'exp' claim without verifying the signature: the token was issued
// by the trusted endpoint one HTTP call ago, and the expiry is only used by
// callers to schedule re-minting.
func FetchToken(ctx context.Context, audience string) (string, time.Time, error) {
func FetchToken(ctx context.Context, audience string) (string, error) {
requestURL := os.Getenv(EnvRequestURL)
requestToken := os.Getenv(EnvRequestToken)
if requestURL == "" || requestToken == "" {
return "", time.Time{}, fmt.Errorf("%s and %s must be set in the environment; "+
return "", fmt.Errorf("%s and %s must be set in the environment; "+
"ensure the Actions job has the 'id-token: write' permission", EnvRequestURL, EnvRequestToken)
}

u, err := url.Parse(requestURL)
if err != nil {
return "", time.Time{}, fmt.Errorf("invalid %s: %w", EnvRequestURL, err)
return "", fmt.Errorf("invalid %s: %w", EnvRequestURL, err)
}
q := u.Query()
q.Set("audience", audience)
u.RawQuery = q.Encode()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to create OIDC token request: %w", err)
return "", fmt.Errorf("failed to create OIDC token request: %w", err)
}
req.Header.Set("Authorization", "bearer "+requestToken)

resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", time.Time{}, fmt.Errorf("OIDC token request failed: %w", err)
return "", fmt.Errorf("OIDC token request failed: %w", err)
}
defer resp.Body.Close()

body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return "", time.Time{}, fmt.Errorf("OIDC token request failed with status %s: %s",
return "", fmt.Errorf("OIDC token request failed with status %s: %s",
resp.Status, strings.TrimSpace(string(body)))
}

var result struct {
Value string `json:"value"`
}
if err := json.Unmarshal(body, &result); err != nil {
return "", time.Time{}, fmt.Errorf("failed to decode OIDC token response: %w", err)
return "", fmt.Errorf("failed to decode OIDC token response: %w", err)
}
if result.Value == "" {
return "", time.Time{}, errors.New("the OIDC token response did not contain a token")
}

exp, err := tokenExpiry(result.Value)
if err != nil {
return "", time.Time{}, err
return "", errors.New("the OIDC token response did not contain a token")
}

return result.Value, exp, nil
}

// tokenExpiry extracts the 'exp' claim from a compact-serialized JWT without
// verifying its signature, matching how auth/generic reads service account
// token expirations.
func tokenExpiry(token string) (time.Time, error) {
tok, _, err := jwt.NewParser().ParseUnverified(token, jwt.MapClaims{})
if err != nil {
return time.Time{}, fmt.Errorf("failed to parse OIDC token: %w", err)
}
exp, err := tok.Claims.GetExpirationTime()
if err != nil {
return time.Time{}, fmt.Errorf("failed to get expiration time from OIDC token: %w", err)
}
if exp == nil {
return time.Time{}, errors.New("OIDC token has no exp claim")
}
return exp.Time, nil
return result.Value, nil
}
35 changes: 6 additions & 29 deletions auth/actionsoidc/actionsoidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ func makeJWT(t *testing.T, exp time.Time) string {
}

func TestFetchToken(t *testing.T) {
t.Run("fetches token and exp", func(t *testing.T) {
exp := time.Now().Add(time.Hour).Truncate(time.Second)
idToken := makeJWT(t, exp)
t.Run("fetches token", func(t *testing.T) {
idToken := makeJWT(t, time.Now().Add(time.Hour))

var gotAudience, gotAuth, gotPath string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -58,16 +57,13 @@ func TestFetchToken(t *testing.T) {
t.Setenv(actionsoidc.EnvRequestURL, srv.URL+"/token?api-version=2.0")
t.Setenv(actionsoidc.EnvRequestToken, "request-token")

token, gotExp, err := actionsoidc.FetchToken(context.Background(), "my-audience")
token, err := actionsoidc.FetchToken(context.Background(), "my-audience")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if token != idToken {
t.Errorf("token = %q, want %q", token, idToken)
}
if !gotExp.Equal(exp) {
t.Errorf("exp = %v, want %v", gotExp, exp)
}
if gotAudience != "my-audience" {
t.Errorf("audience query = %q, want my-audience", gotAudience)
}
Expand All @@ -82,7 +78,7 @@ func TestFetchToken(t *testing.T) {
t.Run("errors when env vars are unset", func(t *testing.T) {
t.Setenv(actionsoidc.EnvRequestURL, "")
t.Setenv(actionsoidc.EnvRequestToken, "")
_, _, err := actionsoidc.FetchToken(context.Background(), "aud")
_, err := actionsoidc.FetchToken(context.Background(), "aud")
if err == nil || !strings.Contains(err.Error(), actionsoidc.EnvRequestURL) {
t.Fatalf("expected error mentioning %s, got: %v", actionsoidc.EnvRequestURL, err)
}
Expand All @@ -97,7 +93,7 @@ func TestFetchToken(t *testing.T) {
t.Setenv(actionsoidc.EnvRequestURL, srv.URL)
t.Setenv(actionsoidc.EnvRequestToken, "request-token")

_, _, err := actionsoidc.FetchToken(context.Background(), "aud")
_, err := actionsoidc.FetchToken(context.Background(), "aud")
if err == nil || !strings.Contains(err.Error(), "denied") {
t.Fatalf("expected error containing response body, got: %v", err)
}
Expand All @@ -111,28 +107,9 @@ func TestFetchToken(t *testing.T) {
t.Setenv(actionsoidc.EnvRequestURL, srv.URL)
t.Setenv(actionsoidc.EnvRequestToken, "request-token")

_, _, err := actionsoidc.FetchToken(context.Background(), "aud")
_, err := actionsoidc.FetchToken(context.Background(), "aud")
if err == nil || !strings.Contains(err.Error(), "did not contain a token") {
t.Fatalf("expected empty-token error, got: %v", err)
}
})

t.Run("errors on token without exp", func(t *testing.T) {
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{})
noExp, err := tok.SignedString([]byte("test-secret"))
if err != nil {
t.Fatalf("failed to mint test JWT: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"value":"` + noExp + `"}`))
}))
defer srv.Close()
t.Setenv(actionsoidc.EnvRequestURL, srv.URL)
t.Setenv(actionsoidc.EnvRequestToken, "request-token")

_, _, err = actionsoidc.FetchToken(context.Background(), "aud")
if err == nil || !strings.Contains(err.Error(), "exp claim") {
t.Fatalf("expected missing-exp error, got: %v", err)
}
})
}
84 changes: 60 additions & 24 deletions auth/utils/cijwt/cijwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ limitations under the License.
// platform's OIDC integration or signing it locally.
//
// Each configured host gets its token one of four ways:
// - WithHostAudience mints an OIDC ID token for the given audience from the
// GitHub/Forgejo Actions token endpoint (see the actionsoidc package),
// caching it for the first 50% of its lifetime and reminting on demand.
// - WithHostTokenFunc invokes a caller-supplied function to obtain the JWT,
// then caches it for the first 50% of its 'exp' claim's remaining lifetime
// and re-invokes the function on demand. The caller decides where the
// token comes from (e.g. actionsoidc.FetchToken for the GitHub/Forgejo
// Actions endpoint, idtoken.NewTokenSource for GCP, etc.).
// - WithHostToken sends a static JWT as-is, e.g. a GitLab CI id_token injected
// into the job environment.
// - WithHostTokenFile reads the JWT from a file for every request, so a token
Expand All @@ -43,7 +45,8 @@ import (
"sync"
"time"

"github.com/fluxcd/pkg/auth/actionsoidc"
gojwt "github.com/golang-jwt/jwt/v5"

"github.com/fluxcd/pkg/auth/jwt"
)

Expand All @@ -65,11 +68,21 @@ type hostJWK struct {
sub string
}

// TokenFunc returns a fresh JWT. The Transport parses the returned token's
// 'exp' claim without verifying the signature and caches it for the first 50%
// of its remaining lifetime, re-invoking fn on demand.
type TokenFunc func(ctx context.Context) (string, error)

type hostTokenFunc struct {
host string
fn TokenFunc
}

type options struct {
inner http.RoundTripper
tokens []hostValue
tokenFiles []hostValue
audiences []hostValue
tokenFns []hostTokenFunc
jwks []hostJWK
}

Expand Down Expand Up @@ -97,11 +110,12 @@ func WithHostTokenFile(host, path string) Option {
return func(o *options) { o.tokenFiles = append(o.tokenFiles, hostValue{host, path}) }
}

// WithHostAudience configures host to be authenticated with an OIDC ID token
// minted for the given audience from the GitHub/Forgejo Actions token endpoint,
// cached for the first 50% of its lifetime and reminted on demand.
func WithHostAudience(host, audience string) Option {
return func(o *options) { o.audiences = append(o.audiences, hostValue{host, audience}) }
// WithHostTokenFunc configures host to be authenticated with a JWT obtained by
// calling fn. The Transport caches the returned token for the first 50% of its
// 'exp' claim's remaining lifetime and re-invokes fn on demand. fn errors and
// tokens missing an 'exp' claim are returned wrapped to the RoundTrip caller.
func WithHostTokenFunc(host string, fn TokenFunc) Option {
return func(o *options) { o.tokenFns = append(o.tokenFns, hostTokenFunc{host, fn}) }
}

// WithHostJWK configures host to be authenticated with a JWT signed locally
Expand Down Expand Up @@ -130,28 +144,28 @@ type jwkConfig struct {

// Transport is an http.RoundTripper that stamps Authorization: Bearer <jwt> on
// requests whose URL host was configured with WithHostToken, WithHostTokenFile,
// WithHostAudience, or WithHostJWK. Any existing Authorization header on a
// WithHostTokenFunc, or WithHostJWK. Any existing Authorization header on a
// configured host is overwritten; requests to other hosts pass through
// untouched.
type Transport struct {
inner http.RoundTripper
// audiences maps a host to the audience minted for it; the factory used on
// a cache miss.
audiences map[string]string
// jwk maps a host to the signing config used to mint a fresh token for
// every request. It is read-only after construction.
jwk map[string]jwkConfig
// tokenFiles maps a host to a file path read on every request. It is
// read-only after construction.
tokenFiles map[string]string
// tokenFns maps a host to the function called to mint a fresh token on a
// cache miss. It is read-only after construction.
tokenFns map[string]TokenFunc

mu sync.Mutex
cache map[string]cacheEntry
}

// NewTransport returns a Transport configured by opts. At least one host must be
// configured. It returns an error if the same host is configured more than once,
// whether via WithHostToken, WithHostTokenFile, WithHostAudience, WithHostJWK,
// whether via WithHostToken, WithHostTokenFile, WithHostTokenFunc, WithHostJWK,
// or a mix of them, or if a WithHostJWK key fails to parse.
func NewTransport(opts ...Option) (*Transport, error) {
o := &options{inner: http.DefaultTransport}
Expand All @@ -161,13 +175,13 @@ func NewTransport(opts ...Option) (*Transport, error) {

t := &Transport{
inner: o.inner,
audiences: make(map[string]string, len(o.audiences)),
jwk: make(map[string]jwkConfig, len(o.jwks)),
tokenFiles: make(map[string]string, len(o.tokenFiles)),
tokenFns: make(map[string]TokenFunc, len(o.tokenFns)),
cache: make(map[string]cacheEntry, len(o.tokens)),
}

seen := make(map[string]bool, len(o.tokens)+len(o.tokenFiles)+len(o.audiences)+len(o.jwks))
seen := make(map[string]bool, len(o.tokens)+len(o.tokenFiles)+len(o.tokenFns)+len(o.jwks))
claim := func(host string) error {
if seen[host] {
return fmt.Errorf("host %q is configured more than once", host)
Expand All @@ -189,11 +203,11 @@ func NewTransport(opts ...Option) (*Transport, error) {
}
t.tokenFiles[hv.host] = hv.value
}
for _, hv := range o.audiences {
if err := claim(hv.host); err != nil {
for _, hf := range o.tokenFns {
if err := claim(hf.host); err != nil {
return nil, err
}
t.audiences[hv.host] = hv.value
t.tokenFns[hf.host] = hf.fn
}
for _, hj := range o.jwks {
if err := claim(hj.host); err != nil {
Expand All @@ -207,7 +221,7 @@ func NewTransport(opts ...Option) (*Transport, error) {
}

if len(seen) == 0 {
return nil, fmt.Errorf("at least one host must be configured with WithHostToken, WithHostTokenFile, WithHostAudience, or WithHostJWK")
return nil, fmt.Errorf("at least one host must be configured with WithHostToken, WithHostTokenFile, WithHostTokenFunc, or WithHostJWK")
}

return t, nil
Expand All @@ -231,7 +245,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
}

// tokenForHost returns the bearer token for host. WithHostJWK and
// WithHostTokenFile hosts get a fresh token on every call; WithHostAudience
// WithHostTokenFile hosts get a fresh token on every call; WithHostTokenFunc
// hosts are minted and cached on a miss. The boolean is false when host was
// not configured.
func (t *Transport) tokenForHost(ctx context.Context, host string) (string, bool, error) {
Expand Down Expand Up @@ -265,13 +279,17 @@ func (t *Transport) tokenForHost(ctx context.Context, host string) (string, bool
return e.token, true, nil
}

audience, ok := t.audiences[host]
fn, ok := t.tokenFns[host]
if !ok {
// Not configured (static hosts are always served from the cache above).
return "", false, nil
}

token, exp, err := actionsoidc.FetchToken(ctx, audience)
token, err := fn(ctx)
if err != nil {
return "", false, err
}
exp, err := tokenExpiry(token)
if err != nil {
return "", false, err
}
Expand All @@ -283,3 +301,21 @@ func (t *Transport) tokenForHost(ctx context.Context, host string) (string, bool
t.cache[host] = cacheEntry{token: token, exp: now.Add(half)}
return token, true, nil
}

// tokenExpiry extracts the 'exp' claim from a compact-serialized JWT without
// verifying its signature. The token was just produced by a caller-supplied
// fn so we trust it; we only need the expiry to schedule re-minting.
func tokenExpiry(token string) (time.Time, error) {
tok, _, err := gojwt.NewParser().ParseUnverified(token, gojwt.MapClaims{})
if err != nil {
return time.Time{}, fmt.Errorf("failed to parse minted token: %w", err)
}
exp, err := tok.Claims.GetExpirationTime()
if err != nil {
return time.Time{}, fmt.Errorf("failed to read exp claim from minted token: %w", err)
}
if exp == nil {
return time.Time{}, fmt.Errorf("minted token has no exp claim")
}
return exp.Time, nil
}
Loading
Loading