Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions pkg/authserver/server/handlers/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,19 @@ func (h *Handler) CallbackHandler(w http.ResponseWriter, req *http.Request) {
return
}

// Exchange code with upstream IDP using the stored PKCE verifier
idpTokens, err := h.upstream.ExchangeCode(ctx, code, pending.UpstreamPKCEVerifier)
// Exchange code and resolve identity in a single atomic operation.
// This ensures OIDC nonce validation cannot be accidentally skipped.
result, err := h.upstream.ExchangeCodeForIdentity(ctx, code, pending.UpstreamPKCEVerifier, pending.UpstreamNonce)
if err != nil {
logger.Errorw("failed to exchange code with upstream IDP",
logger.Errorw("failed to exchange code or resolve identity",
"error", err,
)
h.provider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("failed to exchange authorization code"))
return
}

// Resolve identity from upstream provider
providerSubject, err := h.upstream.ResolveIdentity(ctx, idpTokens, pending.UpstreamNonce)
if err != nil {
logger.Errorw("failed to resolve user identity", "error", err)
h.provider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("failed to verify user identity"))
return
}
idpTokens := result.Tokens
providerSubject := result.Subject

// Get provider ID
providerID := string(h.upstream.Type())
Expand Down
11 changes: 6 additions & 5 deletions pkg/authserver/server/handlers/callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestCallbackHandler_ExchangeCodeFailure(t *testing.T) {

// Configure upstream to fail code exchange
mockUpstream.exchangeErr = assert.AnError
mockUpstream.exchangeTokens = nil
mockUpstream.exchangeResult = nil

// Store a pending authorization
internalState := testInternalState
Expand Down Expand Up @@ -203,8 +203,9 @@ func TestCallbackHandler_IdentityResolutionFailure(t *testing.T) {
t.Parallel()
handler, storState, mockUpstream := handlerTestSetup(t)

// Configure upstream to fail identity resolution
mockUpstream.resolveIdentityErr = assert.AnError
// Configure upstream to fail identity resolution (now part of ExchangeCodeForIdentity)
mockUpstream.exchangeErr = assert.AnError
mockUpstream.exchangeResult = nil

// Store a pending authorization
internalState := testInternalState
Expand All @@ -225,11 +226,11 @@ func TestCallbackHandler_IdentityResolutionFailure(t *testing.T) {

handler.CallbackHandler(rec, req)

// Should fail because identity resolution failed
// Should fail because exchange/identity resolution failed
assert.Equal(t, http.StatusSeeOther, rec.Code)
location := rec.Header().Get("Location")
assert.Contains(t, location, "error=")
assert.Contains(t, location, "failed+to+verify+user+identity")
assert.Contains(t, location, "failed+to+exchange+authorization+code")
}

func TestRoutesIncludeAuthorizeAndCallback(t *testing.T) {
Expand Down
75 changes: 22 additions & 53 deletions pkg/authserver/server/handlers/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,18 @@ const (

// mockIDPProvider implements upstream.OAuth2Provider for testing.
type mockIDPProvider struct {
providerType upstream.ProviderType
authorizationURL string
authURLErr error
exchangeTokens *upstream.Tokens
exchangeErr error
userInfo *upstream.UserInfo
userInfoErr error
refreshTokens *upstream.Tokens
refreshErr error
resolveIdentitySubject string
resolveIdentityErr error
capturedState string
capturedCode string
capturedCodeChallenge string
capturedCodeVerifier string
providerType upstream.ProviderType
authorizationURL string
authURLErr error
exchangeResult *upstream.Identity
exchangeErr error
refreshTokens *upstream.Tokens
refreshErr error
capturedState string
capturedCode string
capturedCodeChallenge string
capturedCodeVerifier string
capturedNonce string
}

// Compile-time interface check.
Expand All @@ -61,23 +58,20 @@ func (m *mockIDPProvider) Type() upstream.ProviderType {
func (m *mockIDPProvider) AuthorizationURL(state, codeChallenge string, _ ...upstream.AuthorizationOption) (string, error) {
m.capturedState = state
m.capturedCodeChallenge = codeChallenge
// Note: We can't easily extract nonce from options since the internal type is private.
// The tests that need to verify nonce will set it directly via the mock setup.
// For the authorize handler test, we capture nonce separately by having the handler
// pass it via WithAdditionalParams and we verify it's non-empty via pending auth storage.
if m.authURLErr != nil {
return "", m.authURLErr
}
return m.authorizationURL + "?state=" + state, nil
}

func (m *mockIDPProvider) ExchangeCode(_ context.Context, code, codeVerifier string) (*upstream.Tokens, error) {
func (m *mockIDPProvider) ExchangeCodeForIdentity(_ context.Context, code, codeVerifier, nonce string) (*upstream.Identity, error) {
m.capturedCode = code
m.capturedCodeVerifier = codeVerifier
m.capturedNonce = nonce
if m.exchangeErr != nil {
return nil, m.exchangeErr
}
return m.exchangeTokens, nil
return m.exchangeResult, nil
}

func (m *mockIDPProvider) RefreshTokens(_ context.Context, _, _ string) (*upstream.Tokens, error) {
Expand All @@ -87,29 +81,6 @@ func (m *mockIDPProvider) RefreshTokens(_ context.Context, _, _ string) (*upstre
return m.refreshTokens, nil
}

func (m *mockIDPProvider) FetchUserInfo(_ context.Context, _ string) (*upstream.UserInfo, error) {
if m.userInfoErr != nil {
return nil, m.userInfoErr
}
return m.userInfo, nil
}

func (m *mockIDPProvider) ResolveIdentity(_ context.Context, _ *upstream.Tokens, _ string) (string, error) {
if m.resolveIdentityErr != nil {
return "", m.resolveIdentityErr
}
// Return explicitly set subject if provided
if m.resolveIdentitySubject != "" {
return m.resolveIdentitySubject, nil
}
// Fallback: return userInfo subject if available
if m.userInfo != nil && m.userInfo.Subject != "" {
return m.userInfo.Subject, nil
}
// Default for tests
return "user-123", nil
}

// testStorageState holds the in-memory state for testing.
type testStorageState struct {
pendingAuths map[string]*storage.PendingAuthorization
Expand Down Expand Up @@ -353,16 +324,14 @@ func handlerTestSetup(t *testing.T) (*Handler, *testStorageState, *mockIDPProvid
mockUpstream := &mockIDPProvider{
providerType: upstream.ProviderTypeOAuth2,
authorizationURL: "https://idp.example.com/authorize",
exchangeTokens: &upstream.Tokens{
AccessToken: "upstream-access-token",
RefreshToken: "upstream-refresh-token",
IDToken: "upstream-id-token",
ExpiresAt: time.Now().Add(time.Hour),
},
userInfo: &upstream.UserInfo{
exchangeResult: &upstream.Identity{
Tokens: &upstream.Tokens{
AccessToken: "upstream-access-token",
RefreshToken: "upstream-refresh-token",
IDToken: "upstream-id-token",
ExpiresAt: time.Now().Add(time.Hour),
},
Subject: "user-123",
Email: "user@example.com",
Name: "Test User",
},
}

Expand Down
25 changes: 9 additions & 16 deletions pkg/authserver/upstream/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,19 @@
//
// - Type: Returns the provider type identifier
// - AuthorizationURL: Build redirect URL for user authentication
// - ExchangeCode: Exchange authorization code for tokens
// - ExchangeCodeForIdentity: Exchange authorization code and resolve identity atomically
// - RefreshTokens: Refresh expired tokens (with subject validation for OIDC)
// - ResolveIdentity: Resolve user identity from tokens
// - FetchUserInfo: Fetch user claims
//
// # Type Hierarchy
//
// OAuth2Provider (interface)
// ├── BaseOAuth2Provider (concrete - pure OAuth 2.0, uses UserInfo for identity)
// ├── BaseOAuth2Provider (concrete - pure OAuth 2.0, uses userinfo endpoint for identity)
// └── OIDCProviderImpl (concrete - OIDC with discovery, validates ID tokens for identity)
//
// # Value Objects
//
// - Tokens: Token response from upstream IDP
// - UserInfo: User claims from UserInfo endpoint
// - Identity: Combined tokens + subject from code exchange
// - OAuth2Config: Configuration for OAuth 2.0 providers
//
// # Usage
Expand All @@ -64,28 +62,23 @@
// // Build authorization URL
// authURL, err := provider.AuthorizationURL(state, pkceChallenge)
//
// // After callback, exchange code for tokens
// tokens, err := provider.ExchangeCode(ctx, code, pkceVerifier)
//
// // Resolve user identity
// subject, err := provider.ResolveIdentity(ctx, tokens, "")
// // After callback, exchange code and resolve identity atomically
// result, err := provider.ExchangeCodeForIdentity(ctx, code, pkceVerifier, nonce)
// // result.Tokens contains the upstream tokens
// // result.Subject contains the canonical user identifier
//
// # Extensibility
//
// To add a new IDP type (e.g., SAML), implement the OAuth2Provider interface.
//
// # UserInfo Extensibility
//
// The package supports flexible UserInfo fetching through the OAuth2Provider
// interface's FetchUserInfo method and UserInfoConfig. This enables:
// The package supports flexible userinfo fetching through UserInfoConfig.
// This enables:
//
// - Custom field mapping for non-standard provider responses
// - Additional headers for provider-specific requirements
//
// All OAuth2Provider implementations support FetchUserInfo directly:
//
// userInfo, err := provider.FetchUserInfo(ctx, accessToken)
//
// For custom provider configuration, use UserInfoConfig:
//
// config := &upstream.UserInfoConfig{
Expand Down
44 changes: 7 additions & 37 deletions pkg/authserver/upstream/mocks/mock_provider.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading