48
48
import org .springframework .security .web .SecurityFilterChain ;
49
49
import org .springframework .test .util .ReflectionTestUtils ;
50
50
import org .springframework .util .ObjectUtils ;
51
+ import org .springframework .web .filter .CompositeFilter ;
51
52
52
53
import static org .assertj .core .api .Assertions .assertThat ;
53
54
@@ -68,7 +69,7 @@ void securityConfigurerConfiguresOAuth2Login() {
68
69
.run ((context ) -> {
69
70
ClientRegistrationRepository expected = context .getBean (ClientRegistrationRepository .class );
70
71
ClientRegistrationRepository actual = (ClientRegistrationRepository ) ReflectionTestUtils .getField (
71
- getFilters (context , OAuth2LoginAuthenticationFilter .class ).get (0 ),
72
+ getSecurityFilters (context , OAuth2LoginAuthenticationFilter .class ).get (0 ),
72
73
"clientRegistrationRepository" );
73
74
assertThat (isEqual (expected .findByRegistrationId ("first" ), actual .findByRegistrationId ("first" )))
74
75
.isTrue ();
@@ -85,7 +86,7 @@ void securityConfigurerConfiguresAuthorizationCode() {
85
86
.run ((context ) -> {
86
87
ClientRegistrationRepository expected = context .getBean (ClientRegistrationRepository .class );
87
88
ClientRegistrationRepository actual = (ClientRegistrationRepository ) ReflectionTestUtils .getField (
88
- getFilters (context , OAuth2AuthorizationCodeGrantFilter .class ).get (0 ),
89
+ getSecurityFilters (context , OAuth2AuthorizationCodeGrantFilter .class ).get (0 ),
89
90
"clientRegistrationRepository" );
90
91
assertThat (isEqual (expected .findByRegistrationId ("first" ), actual .findByRegistrationId ("first" )))
91
92
.isTrue ();
@@ -98,8 +99,8 @@ void securityConfigurerConfiguresAuthorizationCode() {
98
99
void securityConfigurerBacksOffWhenClientRegistrationBeanAbsent () {
99
100
this .contextRunner .withUserConfiguration (TestConfig .class , OAuth2WebSecurityConfiguration .class )
100
101
.run ((context ) -> {
101
- assertThat (getFilters (context , OAuth2LoginAuthenticationFilter .class )).isEmpty ();
102
- assertThat (getFilters (context , OAuth2AuthorizationCodeGrantFilter .class )).isEmpty ();
102
+ assertThat (getSecurityFilters (context , OAuth2LoginAuthenticationFilter .class )).isEmpty ();
103
+ assertThat (getSecurityFilters (context , OAuth2AuthorizationCodeGrantFilter .class )).isEmpty ();
103
104
});
104
105
}
105
106
@@ -124,8 +125,8 @@ void securityFilterChainConfigBacksOffWhenOtherSecurityFilterChainBeanPresent()
124
125
this .contextRunner .withConfiguration (AutoConfigurations .of (WebMvcAutoConfiguration .class ))
125
126
.withUserConfiguration (TestSecurityFilterChainConfiguration .class , OAuth2WebSecurityConfiguration .class )
126
127
.run ((context ) -> {
127
- assertThat (getFilters (context , OAuth2LoginAuthenticationFilter .class )).isEmpty ();
128
- assertThat (getFilters (context , OAuth2AuthorizationCodeGrantFilter .class )).isEmpty ();
128
+ assertThat (getSecurityFilters (context , OAuth2LoginAuthenticationFilter .class )).isEmpty ();
129
+ assertThat (getSecurityFilters (context , OAuth2AuthorizationCodeGrantFilter .class )).isEmpty ();
129
130
assertThat (context ).getBean (OAuth2AuthorizedClientService .class ).isNotNull ();
130
131
});
131
132
}
@@ -137,8 +138,8 @@ void securityFilterChainConfigConditionalOnSecurityFilterChainClass() {
137
138
OAuth2WebSecurityConfiguration .class )
138
139
.withClassLoader (new FilteredClassLoader (SecurityFilterChain .class ))
139
140
.run ((context ) -> {
140
- assertThat (getFilters (context , OAuth2LoginAuthenticationFilter .class )).isEmpty ();
141
- assertThat (getFilters (context , OAuth2AuthorizationCodeGrantFilter .class )).isEmpty ();
141
+ assertThat (getSecurityFilters (context , OAuth2LoginAuthenticationFilter .class )).isEmpty ();
142
+ assertThat (getSecurityFilters (context , OAuth2AuthorizationCodeGrantFilter .class )).isEmpty ();
142
143
});
143
144
}
144
145
@@ -164,11 +165,29 @@ void authorizedClientRepositoryBeanIsConditionalOnMissingBean() {
164
165
});
165
166
}
166
167
167
- private List <Filter > getFilters (AssertableWebApplicationContext context , Class <? extends Filter > filter ) {
168
- FilterChainProxy filterChain = (FilterChainProxy ) context .getBean (BeanIds .SPRING_SECURITY_FILTER_CHAIN );
169
- List <SecurityFilterChain > filterChains = filterChain .getFilterChains ();
170
- List <Filter > filters = filterChains .get (0 ).getFilters ();
171
- return filters .stream ().filter (filter ::isInstance ).toList ();
168
+ private List <Filter > getSecurityFilters (AssertableWebApplicationContext context , Class <? extends Filter > filter ) {
169
+ return getSecurityFilterChain (context ).getFilters ().stream ().filter (filter ::isInstance ).toList ();
170
+ }
171
+
172
+ private SecurityFilterChain getSecurityFilterChain (AssertableWebApplicationContext context ) {
173
+ Filter springSecurityFilterChain = context .getBean (BeanIds .SPRING_SECURITY_FILTER_CHAIN , Filter .class );
174
+ FilterChainProxy filterChainProxy = getFilterChainProxy (springSecurityFilterChain );
175
+ SecurityFilterChain securityFilterChain = filterChainProxy .getFilterChains ().get (0 );
176
+ return securityFilterChain ;
177
+ }
178
+
179
+ private FilterChainProxy getFilterChainProxy (Filter filter ) {
180
+ if (filter instanceof FilterChainProxy filterChainProxy ) {
181
+ return filterChainProxy ;
182
+ }
183
+ if (filter instanceof CompositeFilter ) {
184
+ List <?> filters = (List <?>) ReflectionTestUtils .getField (filter , "filters" );
185
+ return (FilterChainProxy ) filters .stream ()
186
+ .filter (FilterChainProxy .class ::isInstance )
187
+ .findFirst ()
188
+ .orElseThrow ();
189
+ }
190
+ throw new IllegalStateException ("No FilterChainProxy found" );
172
191
}
173
192
174
193
private boolean isEqual (ClientRegistration reg1 , ClientRegistration reg2 ) {
0 commit comments