Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.server.authorization.authentication;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;

import org.jspecify.annotations.Nullable;

import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;

/**
* An {@link OAuth2AuthenticationContext} that holds an
* {@link OAuth2TokenExchangeAuthenticationToken} and additional information and is used
* when validating the OAuth 2.0 Token Exchange Grant Request.
*
* @author Rakesh Kumar Singh
* @since 7.1
* @see OAuth2AuthenticationContext
* @see OAuth2TokenExchangeAuthenticationToken
* @see OAuth2TokenExchangeAuthenticationProvider#setAuthenticationValidator(Consumer)
*/
public final class OAuth2TokenExchangeAuthenticationContext implements OAuth2AuthenticationContext {

private static final String ACTOR_AUTHORIZATION_ATTR_NAME = OAuth2TokenExchangeAuthenticationContext.class.getName()
.concat(".actorAuthorization");

private final Map<Object, Object> context;

private OAuth2TokenExchangeAuthenticationContext(Map<Object, Object> context) {
this.context = Collections.unmodifiableMap(new HashMap<>(context));
}

@SuppressWarnings("unchecked")
@Override
public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}

@Override
public boolean hasKey(Object key) {
Assert.notNull(key, "key cannot be null");
return this.context.containsKey(key);
}

/**
* Returns the {@link RegisteredClient registered client}.
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
}

/**
* Returns the subject {@link OAuth2Authorization authorization}.
* @return the subject {@link OAuth2Authorization}
*/
public OAuth2Authorization getSubjectAuthorization() {
OAuth2Authorization subjectAuthorization = get(OAuth2Authorization.class);
Assert.notNull(subjectAuthorization, "subjectAuthorization cannot be null");
return subjectAuthorization;
}

/**
* Returns the actor {@link OAuth2Authorization authorization}, or {@code null} if not
* available (impersonation case).
* @return the actor {@link OAuth2Authorization}, or {@code null}
*/
public @Nullable OAuth2Authorization getActorAuthorization() {
return get(ACTOR_AUTHORIZATION_ATTR_NAME);
}

/**
* Constructs a new {@link Builder} with the provided
* {@link OAuth2TokenExchangeAuthenticationToken}.
* @param authentication the {@link OAuth2TokenExchangeAuthenticationToken}
* @return the {@link Builder}
*/
public static Builder with(OAuth2TokenExchangeAuthenticationToken authentication) {
return new Builder(authentication);
}

/**
* A builder for {@link OAuth2TokenExchangeAuthenticationContext}.
*/
public static final class Builder
extends AbstractBuilder<OAuth2TokenExchangeAuthenticationContext, Builder> {

private Builder(OAuth2TokenExchangeAuthenticationToken authentication) {
super(authentication);
}

/**
* Sets the {@link RegisteredClient registered client}.
* @param registeredClient the {@link RegisteredClient}
* @return the {@link Builder} for further configuration
*/
public Builder registeredClient(RegisteredClient registeredClient) {
return put(RegisteredClient.class, registeredClient);
}

/**
* Sets the subject {@link OAuth2Authorization}.
* @param subjectAuthorization the subject {@link OAuth2Authorization}
* @return the {@link Builder} for further configuration
*/
public Builder subjectAuthorization(OAuth2Authorization subjectAuthorization) {
return put(OAuth2Authorization.class, subjectAuthorization);
}

/**
* Sets the actor {@link OAuth2Authorization}, or {@code null} for impersonation.
* @param actorAuthorization the actor {@link OAuth2Authorization}, may be
* {@code null}
* @return the {@link Builder} for further configuration
*/
public Builder actorAuthorization(@Nullable OAuth2Authorization actorAuthorization) {
if (actorAuthorization != null) {
getContext().put(ACTOR_AUTHORIZATION_ATTR_NAME, actorAuthorization);
}
return getThis();
}

/**
* Builds a new {@link OAuth2TokenExchangeAuthenticationContext}.
* @return the {@link OAuth2TokenExchangeAuthenticationContext}
*/
@Override
public OAuth2TokenExchangeAuthenticationContext build() {
Assert.notNull(get(RegisteredClient.class), "registeredClient cannot be null");
Assert.notNull(get(OAuth2Authorization.class), "subjectAuthorization cannot be null");
return new OAuth2TokenExchangeAuthenticationContext(getContext());
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -86,6 +87,8 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti

private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;

private Consumer<OAuth2TokenExchangeAuthenticationContext> authenticationValidator = new OAuth2TokenExchangeAuthenticationValidator();

/**
* Constructs an {@code OAuth2TokenExchangeAuthenticationProvider} using the provided
* parameters.
Expand Down Expand Up @@ -204,12 +207,20 @@ else if (authorizedActorClaims != null) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
}

OAuth2TokenExchangeAuthenticationContext authenticationContext = OAuth2TokenExchangeAuthenticationContext
.with(tokenExchangeAuthentication)
.registeredClient(registeredClient)
.subjectAuthorization(subjectAuthorization)
.actorAuthorization(actorAuthorization)
.build();
this.authenticationValidator.accept(authenticationContext);

Set<String> authorizedScopes = Collections.emptySet();
if (!CollectionUtils.isEmpty(tokenExchangeAuthentication.getScopes())) {
authorizedScopes = validateRequestedScopes(registeredClient, tokenExchangeAuthentication.getScopes());
authorizedScopes = new LinkedHashSet<>(tokenExchangeAuthentication.getScopes());
}
else if (!CollectionUtils.isEmpty(subjectAuthorization.getAuthorizedScopes())) {
authorizedScopes = validateRequestedScopes(registeredClient, subjectAuthorization.getAuthorizedScopes());
authorizedScopes = new LinkedHashSet<>(subjectAuthorization.getAuthorizedScopes());
}

// Verify the DPoP Proof (if available)
Expand Down Expand Up @@ -285,16 +296,6 @@ private static boolean isValidTokenType(String tokenType, OAuth2Authorization.To
&& OAuth2TokenFormat.SELF_CONTAINED.getValue().equals(tokenFormat);
}

private static Set<String> validateRequestedScopes(RegisteredClient registeredClient, Set<String> requestedScopes) {
for (String requestedScope : requestedScopes) {
if (!registeredClient.getScopes().contains(requestedScope)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_SCOPE);
}
}

return new LinkedHashSet<>(requestedScopes);
}

private static void validateClaims(Map<String, Object> expectedClaims, @Nullable Map<String, Object> actualClaims,
String... claimNames) {
if (actualClaims == null) {
Expand Down Expand Up @@ -342,4 +343,25 @@ public boolean supports(Class<?> authentication) {
return OAuth2TokenExchangeAuthenticationToken.class.isAssignableFrom(authentication);
}

/**
* Sets the {@code Consumer} providing access to the
* {@link OAuth2TokenExchangeAuthenticationContext} and is responsible for validating
* specific OAuth 2.0 Token Exchange Grant Request parameters associated in the
* {@link OAuth2TokenExchangeAuthenticationToken}. The default authentication validator
* is {@link OAuth2TokenExchangeAuthenticationValidator}.
*
* <p>
* <b>NOTE:</b> The authentication validator MUST throw
* {@link org.springframework.security.oauth2.core.OAuth2AuthenticationException} if
* validation fails.
* @param authenticationValidator the {@code Consumer} providing access to the
* {@link OAuth2TokenExchangeAuthenticationContext} and is responsible for validating
* specific OAuth 2.0 Token Exchange Grant Request parameters
*/
public void setAuthenticationValidator(
Consumer<OAuth2TokenExchangeAuthenticationContext> authenticationValidator) {
Assert.notNull(authenticationValidator, "authenticationValidator cannot be null");
this.authenticationValidator = authenticationValidator;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.server.authorization.authentication;

import java.util.Set;
import java.util.function.Consumer;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.core.log.LogMessage;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.CollectionUtils;

/**
* A {@code Consumer} providing access to the
* {@link OAuth2TokenExchangeAuthenticationContext} containing an
* {@link OAuth2TokenExchangeAuthenticationToken} and is the default
* {@link OAuth2TokenExchangeAuthenticationProvider#setAuthenticationValidator(Consumer)
* authentication validator} used for validating specific OAuth 2.0 Token Exchange Grant
* Request parameters.
*
* <p>
* The default implementation validates
* {@link OAuth2TokenExchangeAuthenticationToken#getScopes()}. If validation fails, an
* {@link OAuth2AuthenticationException} is thrown.
*
* @author Rakesh Kumar Singh
* @since 7.1
* @see OAuth2TokenExchangeAuthenticationContext
* @see OAuth2TokenExchangeAuthenticationToken
* @see OAuth2TokenExchangeAuthenticationProvider#setAuthenticationValidator(Consumer)
*/
public final class OAuth2TokenExchangeAuthenticationValidator
implements Consumer<OAuth2TokenExchangeAuthenticationContext> {

private static final Log LOGGER = LogFactory.getLog(OAuth2TokenExchangeAuthenticationValidator.class);

/**
* The default validator for
* {@link OAuth2TokenExchangeAuthenticationToken#getScopes()}.
*/
public static final Consumer<OAuth2TokenExchangeAuthenticationContext> DEFAULT_SCOPE_VALIDATOR = OAuth2TokenExchangeAuthenticationValidator::validateScope;

private final Consumer<OAuth2TokenExchangeAuthenticationContext> authenticationValidator = DEFAULT_SCOPE_VALIDATOR;

@Override
public void accept(OAuth2TokenExchangeAuthenticationContext authenticationContext) {
this.authenticationValidator.accept(authenticationContext);
}

private static void validateScope(OAuth2TokenExchangeAuthenticationContext authenticationContext) {
OAuth2TokenExchangeAuthenticationToken tokenExchangeAuthentication = authenticationContext.getAuthentication();
RegisteredClient registeredClient = authenticationContext.getRegisteredClient();
OAuth2Authorization subjectAuthorization = authenticationContext.getSubjectAuthorization();

Set<String> requestedScopes = tokenExchangeAuthentication.getScopes();
if (CollectionUtils.isEmpty(requestedScopes)) {
requestedScopes = subjectAuthorization.getAuthorizedScopes();
}

Set<String> allowedScopes = registeredClient.getScopes();
if (!requestedScopes.isEmpty() && !allowedScopes.containsAll(requestedScopes)) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug(LogMessage.format(
"Invalid request: requested scope is not allowed" + " for registered client '%s'",
registeredClient.getId()));
}
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_SCOPE);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,66 @@ public void tearDown() {
AuthorizationServerContextHolder.resetContext();
}

@Test
public void setAuthenticationValidatorWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authenticationProvider.setAuthenticationValidator(null))
.withMessage("authenticationValidator cannot be null");
// @formatter:on
}

@Test
public void authenticateWhenCustomAuthenticationValidatorThenUsed() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.build();
OAuth2TokenExchangeAuthenticationToken authentication = createDelegationRequest(registeredClient);
OAuth2Authorization subjectAuthorization = TestOAuth2Authorizations.authorization(registeredClient)
.token(createAccessToken(SUBJECT_TOKEN))
.build();
OAuth2Authorization actorAuthorization = TestOAuth2Authorizations.authorization(registeredClient)
.token(createAccessToken(ACTOR_TOKEN))
.build();
given(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class)))
.willReturn(subjectAuthorization, actorAuthorization);
OAuth2AccessToken accessToken = createAccessToken("token-value");
given(this.tokenGenerator.generate(any(OAuth2TokenContext.class))).willReturn(accessToken);

Consumer<OAuth2TokenExchangeAuthenticationContext> customValidator = mock(Consumer.class);
this.authenticationProvider.setAuthenticationValidator(customValidator);
this.authenticationProvider.authenticate(authentication);

verify(customValidator).accept(any(OAuth2TokenExchangeAuthenticationContext.class));
}

@Test
public void authenticateWhenCustomAuthenticationValidatorThrowsThenPropagated() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.build();
OAuth2TokenExchangeAuthenticationToken authentication = createDelegationRequest(registeredClient);
OAuth2Authorization subjectAuthorization = TestOAuth2Authorizations.authorization(registeredClient)
.token(createAccessToken(SUBJECT_TOKEN))
.build();
OAuth2Authorization actorAuthorization = TestOAuth2Authorizations.authorization(registeredClient)
.token(createAccessToken(ACTOR_TOKEN))
.build();
given(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class)))
.willReturn(subjectAuthorization, actorAuthorization);

this.authenticationProvider
.setAuthenticationValidator((ctx) -> { throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); });
// @formatter:off
assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.extracting(OAuth2AuthenticationException::getError)
.extracting(OAuth2Error::getErrorCode)
.isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
// @formatter:on
verifyNoInteractions(this.tokenGenerator);
}

@Test
public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
// @formatter:off
Expand Down