Skip to content

Support token relay clientRegistrationId on properties #3751

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,27 @@ forwards the incoming token to outgoing resource requests. The
consumer can be a pure Client (like an SSO application) or a Resource
Server.

////
TODO: support TokenRelay clientRegistrationId
Spring Cloud Gateway Server MVC can forward OAuth2 access tokens downstream to the services
it is proxying using the `TokenRelay` filter.

The `TokenRelay` filter takes one optional parameter, `clientRegistrationId`.
The following example configures a `TokenRelay` filter:

.App.java
.RouteConfiguration.java
[source,java]
----

@Bean
public RouteLocator customRouteLocator(RouteLocatorBuilder builder) {
return builder.routes()
.route("resource", r -> r.path("/resource")
.filters(f -> f.tokenRelay("myregistrationid"))
.uri("http://localhost:9000"))
@Configuration
class RouteConfiguration {

@Bean
public RouterFunction<ServerResponse> gatewayRouterFunctionsTokenRelay() {
return route("resource")
.GET("/resource", http())
.before(uri("https://localhost:9000"))
.filter(tokenRelay("myregistrationid"))
.build();
}
}
----

Expand All @@ -46,19 +48,13 @@ spring:
----

The example above specifies a `clientRegistrationId`, which can be used to obtain and forward an OAuth2 access token for any available `ClientRegistration`.
////

Spring Cloud Gateway Server MVC can forward the OAuth2 access token of the currently authenticated user `oauth2Login()` is used to authenticate the user.
//To add this functionality to the gateway, you can omit the `clientRegistrationId` parameter like this:
To add this functionality to the gateway, you can omit the `clientRegistrationId` parameter like this:

.RouteConfiguration.java
[source,java]
----
import static org.springframework.cloud.gateway.server.mvc.filter.BeforeFilterFunctions.uri;
import static org.springframework.cloud.gateway.server.mvc.filter.TokenRelayFilterFunctions.tokenRelay;
import static org.springframework.cloud.gateway.server.mvc.handler.GatewayRouterFunctions.route;
import static org.springframework.cloud.gateway.server.mvc.handler.HandlerFunctions.http;

@Configuration
class RouteConfiguration {

Expand Down Expand Up @@ -100,9 +96,9 @@ To enable this for Spring Cloud Gateway Server MVC add the following dependencie
- `org.springframework.boot:spring-boot-starter-oauth2-client`

How does it work?
// The filter extracts an OAuth2 access token from the currently authenticated user for the provided `clientRegistrationId`.
// If no `clientRegistrationId` is provided,
The currently authenticated user's own access token (obtained during login) is used and the extracted access token is placed in a request header for the downstream requests.
The filter extracts an OAuth2 access token from the currently authenticated user for the provided `clientRegistrationId`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be clear this functionality was already present but the documentation for it was commented out?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not just the documentation. It seems that, this feature was included as part of the initial filter BUT, is not working. With this PR i'm reenabling the TokenRealy filter through properties, and uncomment the related documentation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test to confirm it is now working?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry, but i don't see an example to test filter functions through properties in gateway-server-mvc to take as base, and my knowledge is limited at this point.

A manual test with properties like these:

spring:
  cloud:
    gateway:
      mvc:
        routes:
          - id: token_relay_test
            uri: https://examplel1.com
            filters:
              - TokenRelay=relay
  security:
    oauth2:
      client:
        provider:
          relay:
            jwk-set-uri: https://localhost:8080/context/path/jwk
            issuer-uri: https://localhost:8080/context/path/issuer
        registration:
          relay:
            provider: relay
            client-authentication-method: client_secret_post
            authorization-grant-type: client_credentials
            client-id: someClientId
            client-secret: someClientSecret

works perfectly using it in a complete spring boot application.

If someone could create a specific test for this, I would be grateful.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jaimesf I have spent some time and added a test to your PR to test this, see this class https://github.com/spring-cloud/spring-cloud-gateway/pull/3751/files#diff-b6a9d0f37208e4836a1687c1714f1ea8058daca82bd634b88237aef65c70bdf8

Unfortunately to do this I had to make several other changes because Spring Security is on the classpath. In addition once Spring Security was on the classpath it revealed a completely unrelated bug which I documented here #3816

If no `clientRegistrationId` is provided,
the currently authenticated user's own access token (obtained during login) is used and the extracted access token is placed in a request header for the downstream requests.

//For a full working sample see https://github.com/spring-cloud-samples/sample-gateway-oauth2login[this project].

Expand Down
10 changes: 10 additions & 0 deletions spring-cloud-gateway-server-mvc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@
<artifactId>spring-boot-testcontainers</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-security</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-starter-stream-rabbit</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions;
import org.springframework.cloud.gateway.server.mvc.filter.FilterAutoConfiguration;
import org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions;
import org.springframework.cloud.gateway.server.mvc.predicate.PredicateAutoConfiguration;
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
import org.springframework.core.type.filter.AssignableTypeFilter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public static HandlerFilterFunction<ServerResponse, ServerResponse> tokenRelay()
return tokenRelay(null);
}

@Shortcut
public static HandlerFilterFunction<ServerResponse, ServerResponse> tokenRelay(String defaultClientRegistrationId) {
return (request, next) -> {
Authentication principal = (Authentication) request.servletRequest().getUserPrincipal();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,13 @@
import org.springframework.cloud.gateway.server.mvc.test.HttpbinTestcontainers;
import org.springframework.cloud.gateway.server.mvc.test.HttpbinUriResolver;
import org.springframework.cloud.gateway.server.mvc.test.LocalServerPortUriResolver;
import org.springframework.cloud.gateway.server.mvc.test.PermitAllSecurityConfiguration;
import org.springframework.cloud.gateway.server.mvc.test.TestLoadBalancerConfig;
import org.springframework.cloud.gateway.server.mvc.test.client.TestRestClient;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Lazy;
import org.springframework.core.Ordered;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.HttpEntity;
Expand All @@ -83,10 +86,12 @@
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurationSupport;
import org.springframework.web.servlet.function.HandlerFunction;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.ServerRequest;
import org.springframework.web.servlet.function.ServerResponse;
import org.springframework.web.servlet.handler.HandlerMappingIntrospector;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.cloud.gateway.server.mvc.filter.AfterFilterFunctions.DedupeStrategy.RETAIN_FIRST;
Expand Down Expand Up @@ -142,8 +147,8 @@
import static org.springframework.web.servlet.function.RequestPredicates.path;

@SuppressWarnings("unchecked")
@SpringBootTest(properties = { "spring.http.client.factory=jdk", "spring.cloud.gateway.function.enabled=false" },
webEnvironment = WebEnvironment.RANDOM_PORT)
@SpringBootTest(properties = { "spring.http.client.factory=jdk", "spring.cloud.gateway.function.enabled=false",
"logging.level.org.springframework.security=TRACE" }, webEnvironment = WebEnvironment.RANDOM_PORT)
@ContextConfiguration(initializers = HttpbinTestcontainers.class)
@ExtendWith(OutputCaptureExtension.class)
public class ServerMvcIntegrationTests {
Expand Down Expand Up @@ -317,7 +322,7 @@ public void setStatusGatewayRouterFunctionWorks() {
.isEqualTo(HttpStatus.TOO_MANY_REQUESTS)
.expectHeader()
.valueEquals("x-status", "201"); // .expectBody(String.class).isEqualTo("Failed
// with 201");
// with 201");
}

@Test
Expand Down Expand Up @@ -1026,7 +1031,8 @@ void logsArtifactDeprecatedWarning(CapturedOutput output) {
@SpringBootConfiguration
@EnableAutoConfiguration
@LoadBalancerClient(name = "httpbin", configuration = TestLoadBalancerConfig.Httpbin.class)
protected static class TestConfiguration {
@Import(PermitAllSecurityConfiguration.class)
protected static class TestConfiguration extends WebMvcConfigurationSupport {

@Bean
StaticPortController staticPortController() {
Expand All @@ -1043,6 +1049,23 @@ EventController eventController() {
return new EventController();
}

// TODO This is needed to work around https://github.com/spring-cloud/spring-cloud-gateway/issues/3816
// which results from Spring Security being on the classpath. Once we can address this issue we should
// remove this bean and no longer extend WebMvcConfigurationSupport in this configuration class
@Bean
@Lazy
@Override
public HandlerMappingIntrospector mvcHandlerMappingIntrospector() {
return new HandlerMappingIntrospector() {
@Override
public Filter createCacheFilter() {
return (request, response, chain) -> {
chain.doFilter(request, response);
};
}
};
}

@Bean
public AsyncProxyManager<String> caffeineProxyManager() {
Caffeine<String, RemoteBucketState> builder = (Caffeine) Caffeine.newBuilder().maximumSize(100);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.cloud.gateway.server.mvc.test.HttpbinTestcontainers;
import org.springframework.cloud.gateway.server.mvc.test.HttpbinUriResolver;
import org.springframework.cloud.gateway.server.mvc.test.PermitAllSecurityConfiguration;
import org.springframework.cloud.gateway.server.mvc.test.TestLoadBalancerConfig;
import org.springframework.cloud.gateway.server.mvc.test.client.TestRestClient;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.RouterFunctions;
Expand Down Expand Up @@ -72,6 +74,7 @@ public void routerFunctionsRouteWorks() {
@SpringBootConfiguration
@EnableAutoConfiguration
@LoadBalancerClient(name = "httpbin", configuration = TestLoadBalancerConfig.Httpbin.class)
@Import(PermitAllSecurityConfiguration.class)
protected static class TestConfiguration {

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.cloud.gateway.server.mvc.test.PermitAllSecurityConfiguration;
import org.springframework.cloud.gateway.server.mvc.test.client.TestRestClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.http.MediaType;
import org.springframework.test.context.ActiveProfiles;

Expand Down Expand Up @@ -78,6 +80,7 @@ public void testSupplierFunctionWorks() {

@SpringBootConfiguration
@EnableAutoConfiguration
@Import(PermitAllSecurityConfiguration.class)
protected static class TestConfiguration {

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@
import org.springframework.cloud.context.refresh.ContextRefresher;
import org.springframework.cloud.gateway.server.mvc.common.MvcUtils;
import org.springframework.cloud.gateway.server.mvc.test.HttpbinTestcontainers;
import org.springframework.cloud.gateway.server.mvc.test.PermitAllSecurityConfiguration;
import org.springframework.cloud.gateway.server.mvc.test.TestLoadBalancerConfig;
import org.springframework.cloud.gateway.server.mvc.test.client.TestRestClient;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClient;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Import;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpMethod;
import org.springframework.test.context.ActiveProfiles;
Expand Down Expand Up @@ -230,6 +232,7 @@ void refreshWorks(ConfigurableApplicationContext context) {
@SpringBootConfiguration
@EnableAutoConfiguration
@LoadBalancerClient(name = "httpbin", configuration = TestLoadBalancerConfig.Httpbin.class)
@Import(PermitAllSecurityConfiguration.class)
static class Config {

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.testcontainers.service.connection.ServiceConnection;
import org.springframework.cloud.gateway.server.mvc.test.PermitAllSecurityConfiguration;
import org.springframework.cloud.gateway.server.mvc.test.client.TestRestClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.http.MediaType;
import org.springframework.test.context.ActiveProfiles;

Expand Down Expand Up @@ -88,6 +90,7 @@ public void testTemplatedStreamWorks() {

@SpringBootConfiguration
@EnableAutoConfiguration
@Import(PermitAllSecurityConfiguration.class)
protected static class TestConfiguration {

@Bean
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright 2013-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.cloud.gateway.server.mvc.config;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.cloud.gateway.server.mvc.test.HttpbinTestcontainers;
import org.springframework.cloud.gateway.server.mvc.test.TestAutoConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.context.WebApplicationContext;

import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT;
import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

/**
* @author Ryan Baxter
*/
@SpringBootTest(webEnvironment = RANDOM_PORT)
@ContextConfiguration(initializers = HttpbinTestcontainers.class)
@ActiveProfiles("tokenrelay")
public class TokenRelayConfigTests {

@Autowired
private WebApplicationContext context;

private MockMvc mvc;

@BeforeEach
public void setup() {
mvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()).build();
}

@Test
@WithMockUser
public void testTokenRelay() throws Exception {
mvc.perform(get("/bearer"))
.andExpect(status().isOk())
.andExpect(content().json("{\"authenticated\": true, \"token\": \"test\"}"));
}

@EnableAutoConfiguration
@SpringBootConfiguration
@Import(TestAutoConfiguration.class)
public static class TestConfig {

@Bean
public OAuth2AuthorizedClientManager authorizedClientManager() {
OAuth2AuthorizedClientManager manager = mock(OAuth2AuthorizedClientManager.class);
OAuth2AuthorizedClient client = mock(OAuth2AuthorizedClient.class);
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
when(accessToken.getTokenValue()).thenReturn("test");
when(client.getAccessToken()).thenReturn(accessToken);
// The client registration id is set in the token relay filter and must match
when(manager.authorize(argThat(
oAuth2AuthorizeRequest -> "token".equals(oAuth2AuthorizeRequest.getClientRegistrationId()))))
.thenReturn(client);
return manager;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.cloud.gateway.server.mvc.test.HttpbinTestcontainers;
import org.springframework.cloud.gateway.server.mvc.test.HttpbinUriResolver;
import org.springframework.cloud.gateway.server.mvc.test.PermitAllSecurityConfiguration;
import org.springframework.cloud.gateway.server.mvc.test.TestLoadBalancerConfig;
import org.springframework.cloud.gateway.server.mvc.test.client.TestRestClient;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -117,6 +119,7 @@ void raisedErrorWhenRemoveJsonAttributes() {
@SpringBootConfiguration
@EnableAutoConfiguration
@LoadBalancerClient(name = "httpbin", configuration = TestLoadBalancerConfig.Httpbin.class)
@Import(PermitAllSecurityConfiguration.class)
protected static class TestConfiguration {

@Bean
Expand Down
Loading
Loading