Skip to content

Commit 1d10e51

Browse files
committed
Adapt to upstream Spring Security changes
1 parent 5915db0 commit 1d10e51

File tree

3 files changed

+86
-24
lines changed

3 files changed

+86
-24
lines changed

spring-boot-project/spring-boot-actuator-autoconfigure/src/test/java/org/springframework/boot/actuate/autoconfigure/cloudfoundry/servlet/CloudFoundryActuatorAutoConfigurationTests.java

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
import java.util.Arrays;
2020
import java.util.Collection;
21+
import java.util.List;
2122

23+
import jakarta.servlet.Filter;
2224
import org.junit.jupiter.api.Test;
2325

2426
import org.springframework.boot.actuate.autoconfigure.endpoint.EndpointAutoConfiguration;
@@ -43,6 +45,7 @@
4345
import org.springframework.boot.autoconfigure.web.client.RestTemplateAutoConfiguration;
4446
import org.springframework.boot.autoconfigure.web.servlet.DispatcherServletAutoConfiguration;
4547
import org.springframework.boot.autoconfigure.web.servlet.WebMvcAutoConfiguration;
48+
import org.springframework.boot.test.context.assertj.AssertableWebApplicationContext;
4649
import org.springframework.boot.test.context.runner.WebApplicationContextRunner;
4750
import org.springframework.context.ApplicationContext;
4851
import org.springframework.http.HttpMethod;
@@ -55,6 +58,7 @@
5558
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
5659
import org.springframework.web.client.RestTemplate;
5760
import org.springframework.web.cors.CorsConfiguration;
61+
import org.springframework.web.filter.CompositeFilter;
5862

5963
import static org.assertj.core.api.Assertions.assertThat;
6064
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@@ -173,9 +177,7 @@ void cloudFoundryPathsIgnoredBySpringSecurity() {
173177
this.contextRunner.withBean(TestEndpoint.class, TestEndpoint::new)
174178
.withPropertyValues("VCAP_APPLICATION:---", "vcap.application.application_id:my-app-id")
175179
.run((context) -> {
176-
FilterChainProxy securityFilterChain = (FilterChainProxy) context
177-
.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
178-
SecurityFilterChain chain = securityFilterChain.getFilterChains().get(0);
180+
SecurityFilterChain chain = getSecurityFilterChain(context);
179181
assertThat(chain.getFilters()).isEmpty();
180182
MockHttpServletRequest request = new MockHttpServletRequest();
181183
testCloudFoundrySecurity(request, BASE_PATH, chain);
@@ -189,6 +191,27 @@ void cloudFoundryPathsIgnoredBySpringSecurity() {
189191
});
190192
}
191193

194+
private SecurityFilterChain getSecurityFilterChain(AssertableWebApplicationContext context) {
195+
Filter springSecurityFilterChain = context.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN, Filter.class);
196+
FilterChainProxy filterChainProxy = getFilterChainProxy(springSecurityFilterChain);
197+
SecurityFilterChain securityFilterChain = filterChainProxy.getFilterChains().get(0);
198+
return securityFilterChain;
199+
}
200+
201+
private FilterChainProxy getFilterChainProxy(Filter filter) {
202+
if (filter instanceof FilterChainProxy filterChainProxy) {
203+
return filterChainProxy;
204+
}
205+
if (filter instanceof CompositeFilter) {
206+
List<?> filters = (List<?>) ReflectionTestUtils.getField(filter, "filters");
207+
return (FilterChainProxy) filters.stream()
208+
.filter(FilterChainProxy.class::isInstance)
209+
.findFirst()
210+
.orElseThrow();
211+
}
212+
throw new IllegalStateException("No FilterChainProxy found");
213+
}
214+
192215
private static void testCloudFoundrySecurity(MockHttpServletRequest request, String servletPath,
193216
SecurityFilterChain chain) {
194217
request.setServletPath(servletPath);

spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/client/servlet/OAuth2WebSecurityConfigurationTests.java

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.springframework.security.web.SecurityFilterChain;
4949
import org.springframework.test.util.ReflectionTestUtils;
5050
import org.springframework.util.ObjectUtils;
51+
import org.springframework.web.filter.CompositeFilter;
5152

5253
import static org.assertj.core.api.Assertions.assertThat;
5354

@@ -68,7 +69,7 @@ void securityConfigurerConfiguresOAuth2Login() {
6869
.run((context) -> {
6970
ClientRegistrationRepository expected = context.getBean(ClientRegistrationRepository.class);
7071
ClientRegistrationRepository actual = (ClientRegistrationRepository) ReflectionTestUtils.getField(
71-
getFilters(context, OAuth2LoginAuthenticationFilter.class).get(0),
72+
getSecurityFilters(context, OAuth2LoginAuthenticationFilter.class).get(0),
7273
"clientRegistrationRepository");
7374
assertThat(isEqual(expected.findByRegistrationId("first"), actual.findByRegistrationId("first")))
7475
.isTrue();
@@ -85,7 +86,7 @@ void securityConfigurerConfiguresAuthorizationCode() {
8586
.run((context) -> {
8687
ClientRegistrationRepository expected = context.getBean(ClientRegistrationRepository.class);
8788
ClientRegistrationRepository actual = (ClientRegistrationRepository) ReflectionTestUtils.getField(
88-
getFilters(context, OAuth2AuthorizationCodeGrantFilter.class).get(0),
89+
getSecurityFilters(context, OAuth2AuthorizationCodeGrantFilter.class).get(0),
8990
"clientRegistrationRepository");
9091
assertThat(isEqual(expected.findByRegistrationId("first"), actual.findByRegistrationId("first")))
9192
.isTrue();
@@ -98,8 +99,8 @@ void securityConfigurerConfiguresAuthorizationCode() {
9899
void securityConfigurerBacksOffWhenClientRegistrationBeanAbsent() {
99100
this.contextRunner.withUserConfiguration(TestConfig.class, OAuth2WebSecurityConfiguration.class)
100101
.run((context) -> {
101-
assertThat(getFilters(context, OAuth2LoginAuthenticationFilter.class)).isEmpty();
102-
assertThat(getFilters(context, OAuth2AuthorizationCodeGrantFilter.class)).isEmpty();
102+
assertThat(getSecurityFilters(context, OAuth2LoginAuthenticationFilter.class)).isEmpty();
103+
assertThat(getSecurityFilters(context, OAuth2AuthorizationCodeGrantFilter.class)).isEmpty();
103104
});
104105
}
105106

@@ -124,8 +125,8 @@ void securityFilterChainConfigBacksOffWhenOtherSecurityFilterChainBeanPresent()
124125
this.contextRunner.withConfiguration(AutoConfigurations.of(WebMvcAutoConfiguration.class))
125126
.withUserConfiguration(TestSecurityFilterChainConfiguration.class, OAuth2WebSecurityConfiguration.class)
126127
.run((context) -> {
127-
assertThat(getFilters(context, OAuth2LoginAuthenticationFilter.class)).isEmpty();
128-
assertThat(getFilters(context, OAuth2AuthorizationCodeGrantFilter.class)).isEmpty();
128+
assertThat(getSecurityFilters(context, OAuth2LoginAuthenticationFilter.class)).isEmpty();
129+
assertThat(getSecurityFilters(context, OAuth2AuthorizationCodeGrantFilter.class)).isEmpty();
129130
assertThat(context).getBean(OAuth2AuthorizedClientService.class).isNotNull();
130131
});
131132
}
@@ -137,8 +138,8 @@ void securityFilterChainConfigConditionalOnSecurityFilterChainClass() {
137138
OAuth2WebSecurityConfiguration.class)
138139
.withClassLoader(new FilteredClassLoader(SecurityFilterChain.class))
139140
.run((context) -> {
140-
assertThat(getFilters(context, OAuth2LoginAuthenticationFilter.class)).isEmpty();
141-
assertThat(getFilters(context, OAuth2AuthorizationCodeGrantFilter.class)).isEmpty();
141+
assertThat(getSecurityFilters(context, OAuth2LoginAuthenticationFilter.class)).isEmpty();
142+
assertThat(getSecurityFilters(context, OAuth2AuthorizationCodeGrantFilter.class)).isEmpty();
142143
});
143144
}
144145

@@ -164,11 +165,29 @@ void authorizedClientRepositoryBeanIsConditionalOnMissingBean() {
164165
});
165166
}
166167

167-
private List<Filter> getFilters(AssertableWebApplicationContext context, Class<? extends Filter> filter) {
168-
FilterChainProxy filterChain = (FilterChainProxy) context.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
169-
List<SecurityFilterChain> filterChains = filterChain.getFilterChains();
170-
List<Filter> filters = filterChains.get(0).getFilters();
171-
return filters.stream().filter(filter::isInstance).toList();
168+
private List<Filter> getSecurityFilters(AssertableWebApplicationContext context, Class<? extends Filter> filter) {
169+
return getSecurityFilterChain(context).getFilters().stream().filter(filter::isInstance).toList();
170+
}
171+
172+
private SecurityFilterChain getSecurityFilterChain(AssertableWebApplicationContext context) {
173+
Filter springSecurityFilterChain = context.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN, Filter.class);
174+
FilterChainProxy filterChainProxy = getFilterChainProxy(springSecurityFilterChain);
175+
SecurityFilterChain securityFilterChain = filterChainProxy.getFilterChains().get(0);
176+
return securityFilterChain;
177+
}
178+
179+
private FilterChainProxy getFilterChainProxy(Filter filter) {
180+
if (filter instanceof FilterChainProxy filterChainProxy) {
181+
return filterChainProxy;
182+
}
183+
if (filter instanceof CompositeFilter) {
184+
List<?> filters = (List<?>) ReflectionTestUtils.getField(filter, "filters");
185+
return (FilterChainProxy) filters.stream()
186+
.filter(FilterChainProxy.class::isInstance)
187+
.findFirst()
188+
.orElseThrow();
189+
}
190+
throw new IllegalStateException("No FilterChainProxy found");
172191
}
173192

174193
private boolean isEqual(ClientRegistration reg1, ClientRegistration reg2) {

spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyAutoConfigurationTests.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestFilter;
4747
import org.springframework.security.web.FilterChainProxy;
4848
import org.springframework.security.web.SecurityFilterChain;
49+
import org.springframework.test.util.ReflectionTestUtils;
50+
import org.springframework.web.filter.CompositeFilter;
4951

5052
import static org.assertj.core.api.Assertions.assertThat;
5153
import static org.mockito.Mockito.mock;
@@ -208,15 +210,15 @@ void relyingPartyRegistrationRepositoryShouldBeConditionalOnMissingBean() {
208210
@Test
209211
void samlLoginShouldBeConfigured() {
210212
this.contextRunner.withPropertyValues(getPropertyValues())
211-
.run((context) -> assertThat(hasFilter(context, Saml2WebSsoAuthenticationFilter.class)).isTrue());
213+
.run((context) -> assertThat(hasSecurityFilter(context, Saml2WebSsoAuthenticationFilter.class)).isTrue());
212214
}
213215

214216
@Test
215217
void samlLoginShouldBackOffWhenASecurityFilterChainBeanIsPresent() {
216218
this.contextRunner.withConfiguration(AutoConfigurations.of(WebMvcAutoConfiguration.class))
217219
.withUserConfiguration(TestSecurityFilterChainConfig.class)
218220
.withPropertyValues(getPropertyValues())
219-
.run((context) -> assertThat(hasFilter(context, Saml2WebSsoAuthenticationFilter.class)).isFalse());
221+
.run((context) -> assertThat(hasSecurityFilter(context, Saml2WebSsoAuthenticationFilter.class)).isFalse());
220222
}
221223

222224
@Test
@@ -229,7 +231,7 @@ void samlLoginShouldShouldBeConditionalOnSecurityWebFilterClass() {
229231
@Test
230232
void samlLogoutShouldBeConfigured() {
231233
this.contextRunner.withPropertyValues(getPropertyValues())
232-
.run((context) -> assertThat(hasFilter(context, Saml2LogoutRequestFilter.class)).isTrue());
234+
.run((context) -> assertThat(hasSecurityFilter(context, Saml2LogoutRequestFilter.class)).isTrue());
233235
}
234236

235237
private String[] getPropertyValuesWithoutSigningCredentials(boolean signRequests) {
@@ -323,11 +325,29 @@ private String[] getPropertyValues() {
323325
PREFIX + ".foo.acs.binding=redirect" };
324326
}
325327

326-
private boolean hasFilter(AssertableWebApplicationContext context, Class<? extends Filter> filter) {
327-
FilterChainProxy filterChain = (FilterChainProxy) context.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
328-
List<SecurityFilterChain> filterChains = filterChain.getFilterChains();
329-
List<Filter> filters = filterChains.get(0).getFilters();
330-
return filters.stream().anyMatch(filter::isInstance);
328+
private boolean hasSecurityFilter(AssertableWebApplicationContext context, Class<? extends Filter> filter) {
329+
return getSecurityFilterChain(context).getFilters().stream().anyMatch(filter::isInstance);
330+
}
331+
332+
private SecurityFilterChain getSecurityFilterChain(AssertableWebApplicationContext context) {
333+
Filter springSecurityFilterChain = context.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN, Filter.class);
334+
FilterChainProxy filterChainProxy = getFilterChainProxy(springSecurityFilterChain);
335+
SecurityFilterChain securityFilterChain = filterChainProxy.getFilterChains().get(0);
336+
return securityFilterChain;
337+
}
338+
339+
private FilterChainProxy getFilterChainProxy(Filter filter) {
340+
if (filter instanceof FilterChainProxy filterChainProxy) {
341+
return filterChainProxy;
342+
}
343+
if (filter instanceof CompositeFilter) {
344+
List<?> filters = (List<?>) ReflectionTestUtils.getField(filter, "filters");
345+
return (FilterChainProxy) filters.stream()
346+
.filter(FilterChainProxy.class::isInstance)
347+
.findFirst()
348+
.orElseThrow();
349+
}
350+
throw new IllegalStateException("No FilterChainProxy found");
331351
}
332352

333353
private void setupMockResponse(MockWebServer server, Resource resourceBody) throws Exception {

0 commit comments

Comments
 (0)