Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions auth/client/iam/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ func (hb HTTPClient) PresentationDefinition(ctx context.Context, presentationDef
if err != nil {
// any OAuth error should be passed
// any other error should result in a 502 Bad Gateway
if oauthErr, ok := err.(oauth.OAuth2Error); ok {
return nil, oauthErr
if errors.As(err, new(oauth.OAuth2Error)) {
return nil, err
}
return nil, errors.Join(ErrBadGateway, err)
}
Expand Down Expand Up @@ -202,7 +202,7 @@ func (hb HTTPClient) AccessToken(ctx context.Context, tokenEndpoint string, data
return token, fmt.Errorf("unable to unmarshal OAuth error response: %w", err)
}

return token, oauthError
return token, oauth.RemoteOAuthError{Cause: oauthError}
}

var responseData []byte
Expand Down Expand Up @@ -397,10 +397,10 @@ func (hb HTTPClient) doRequest(ctx context.Context, request *http.Request, targe
if httpErr := core.TestResponseCode(http.StatusOK, response); httpErr != nil {
rse := httpErr.(core.HttpError)
if ok, oauthErr := oauth.TestOAuthErrorCode(rse.ResponseBody, oauth.InvalidScope); ok {
return oauthErr
return oauth.RemoteOAuthError{Cause: oauthErr}
}
if ok, oauthErr := oauth.TestOAuthErrorCode(rse.ResponseBody, oauth.InvalidRequest); ok {
return oauthErr
return oauth.RemoteOAuthError{Cause: oauthErr}
}
return httpErr
}
Expand Down
12 changes: 7 additions & 5 deletions auth/client/iam/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,10 @@ func TestHTTPClient_PresentationDefinition(t *testing.T) {
_, err := client.PresentationDefinition(ctx, *pdUrl)

require.Error(t, err)
oauthErr, ok := err.(oauth.OAuth2Error)
require.True(t, ok)
var oauthErr oauth.OAuth2Error
require.ErrorAs(t, err, &oauthErr)
assert.Equal(t, oauth.InvalidRequest, oauthErr.Code)
require.ErrorAs(t, err, new(oauth.RemoteOAuthError))
})
}

Expand Down Expand Up @@ -169,10 +170,11 @@ func TestHTTPClient_AccessToken(t *testing.T) {
_, err := client.AccessToken(ctx, tlsServer.URL, data, dpopHeader)

require.Error(t, err)
// check if the error is an OAuth error
oauthError, ok := err.(oauth.OAuth2Error)
require.True(t, ok)
// check if the error is a remote OAuth error
var oauthError oauth.OAuth2Error
require.ErrorAs(t, err, &oauthError)
assert.Equal(t, oauth.InvalidRequest, oauthError.Code)
require.ErrorAs(t, err, new(oauth.RemoteOAuthError))
})
t.Run("error - generic server error", func(t *testing.T) {
handler := http2.Handler{StatusCode: http.StatusBadGateway, ResponseData: "offline"}
Expand Down
7 changes: 4 additions & 3 deletions auth/client/iam/openid4vp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,10 @@ func TestRelyingParty_RequestRFC021AccessToken(t *testing.T) {
_, err := ctx.client.RequestRFC021AccessToken(context.Background(), subjectClientID, subjectID, ctx.verifierURL.String(), scopes, false, nil)

require.Error(t, err)
oauthError, ok := err.(oauth.OAuth2Error)
require.True(t, ok)
assert.Equal(t, oauth.InvalidScope, oauthError.Code)
var oauthErrResult oauth.OAuth2Error
require.ErrorAs(t, err, &oauthErrResult)
assert.Equal(t, oauth.InvalidScope, oauthErrResult.Code)
require.ErrorAs(t, err, new(oauth.RemoteOAuthError))
})
t.Run("error - failed to get presentation definition", func(t *testing.T) {
ctx := createClientServerTestContext(t)
Expand Down
22 changes: 22 additions & 0 deletions auth/oauth/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,25 @@ func TestOAuthErrorCode(responseBody []byte, code ErrorCode) (bool, OAuth2Error)
}
return oauthErr.Code == code, oauthErr
}

// RemoteOAuthError wraps an OAuth2Error to indicate it was returned by a remote authorization server.
// This allows callers to distinguish between locally-generated and remote OAuth2 errors.
type RemoteOAuthError struct {
// Cause is the underlying OAuth2Error returned by the remote server.
Cause OAuth2Error
}

// Error returns the error message, prefixed with "remote authorization server" to indicate the error origin.
func (e RemoteOAuthError) Error() string {
return "remote authorization server: " + e.Cause.Error()
}

// StatusCode returns the HTTP status code matching the underlying OAuth2 error code.
func (e RemoteOAuthError) StatusCode() int {
return e.Cause.StatusCode()
}

// Unwrap returns the underlying OAuth2Error, allowing errors.As to find it.
func (e RemoteOAuthError) Unwrap() error {
return e.Cause
}
19 changes: 19 additions & 0 deletions auth/oauth/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ func TestError_Error(t *testing.T) {
})
}

func TestRemoteOAuthError(t *testing.T) {
t.Run("Error() prefixes with remote authorization server", func(t *testing.T) {
err := RemoteOAuthError{Cause: OAuth2Error{Code: InvalidRequest, Description: "bad scope"}}
assert.EqualError(t, err, "remote authorization server: invalid_request - bad scope")
})
t.Run("StatusCode() delegates to underlying OAuth2Error", func(t *testing.T) {
err := RemoteOAuthError{Cause: OAuth2Error{Code: ServerError}}
assert.Equal(t, http.StatusInternalServerError, err.StatusCode())
err = RemoteOAuthError{Cause: OAuth2Error{Code: InvalidRequest}}
assert.Equal(t, http.StatusBadRequest, err.StatusCode())
})
t.Run("errors.As finds underlying OAuth2Error", func(t *testing.T) {
wrapped := RemoteOAuthError{Cause: OAuth2Error{Code: InvalidRequest, Description: "bad scope"}}
var oauthErr OAuth2Error
assert.True(t, errors.As(wrapped, &oauthErr))
assert.Equal(t, InvalidRequest, oauthErr.Code)
})
}

func Test_oauth2ErrorWriter_Write(t *testing.T) {
t.Run("user-agent is browser with redirect URI", func(t *testing.T) {
server := echo.New()
Expand Down
Loading