34
34
import org .springframework .security .core .AuthenticationException ;
35
35
import org .springframework .security .core .context .SecurityContext ;
36
36
import org .springframework .security .core .context .SecurityContextHolder ;
37
+ import org .springframework .security .oauth2 .core .ClientAuthenticationMethod ;
37
38
import org .springframework .security .oauth2 .core .OAuth2AuthenticationException ;
38
39
import org .springframework .security .oauth2 .core .OAuth2Error ;
39
40
import org .springframework .security .oauth2 .core .OAuth2ErrorCodes ;
53
54
import org .springframework .security .web .authentication .AuthenticationSuccessHandler ;
54
55
import org .springframework .security .web .authentication .DelegatingAuthenticationConverter ;
55
56
import org .springframework .security .web .authentication .WebAuthenticationDetailsSource ;
57
+ import org .springframework .security .web .authentication .www .BasicAuthenticationEntryPoint ;
56
58
import org .springframework .security .web .util .matcher .RequestMatcher ;
57
59
import org .springframework .util .Assert ;
58
60
import org .springframework .web .filter .OncePerRequestFilter ;
@@ -90,6 +92,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
90
92
91
93
private final AuthenticationDetailsSource <HttpServletRequest , ?> authenticationDetailsSource = new WebAuthenticationDetailsSource ();
92
94
95
+ private final BasicAuthenticationEntryPoint basicAuthenticationEntryPoint = new BasicAuthenticationEntryPoint ();
96
+
93
97
private AuthenticationConverter authenticationConverter ;
94
98
95
99
private AuthenticationSuccessHandler authenticationSuccessHandler = this ::onAuthenticationSuccess ;
@@ -110,6 +114,7 @@ public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationMana
110
114
Assert .notNull (requestMatcher , "requestMatcher cannot be null" );
111
115
this .authenticationManager = authenticationManager ;
112
116
this .requestMatcher = requestMatcher ;
117
+ this .basicAuthenticationEntryPoint .setRealmName ("default" );
113
118
// @formatter:off
114
119
this .authenticationConverter = new DelegatingAuthenticationConverter (
115
120
Arrays .asList (
@@ -130,8 +135,9 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
130
135
return ;
131
136
}
132
137
138
+ Authentication authenticationRequest = null ;
133
139
try {
134
- Authentication authenticationRequest = this .authenticationConverter .convert (request );
140
+ authenticationRequest = this .authenticationConverter .convert (request );
135
141
if (authenticationRequest instanceof AbstractAuthenticationToken authenticationToken ) {
136
142
authenticationToken .setDetails (this .authenticationDetailsSource .buildDetails (request ));
137
143
}
@@ -147,7 +153,14 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
147
153
if (this .logger .isTraceEnabled ()) {
148
154
this .logger .trace (LogMessage .format ("Client authentication failed: %s" , ex .getError ()), ex );
149
155
}
150
- this .authenticationFailureHandler .onAuthenticationFailure (request , response , ex );
156
+ if (authenticationRequest instanceof OAuth2ClientAuthenticationToken clientAuthentication ) {
157
+ this .authenticationFailureHandler .onAuthenticationFailure (request , response ,
158
+ new OAuth2ClientAuthenticationException (ex .getError (), ex , clientAuthentication ));
159
+ }
160
+ else {
161
+ this .authenticationFailureHandler .onAuthenticationFailure (request , response , ex );
162
+ }
163
+
151
164
}
152
165
}
153
166
@@ -199,21 +212,21 @@ private void onAuthenticationSuccess(HttpServletRequest request, HttpServletResp
199
212
}
200
213
201
214
private void onAuthenticationFailure (HttpServletRequest request , HttpServletResponse response ,
202
- AuthenticationException exception ) throws IOException {
215
+ AuthenticationException authenticationException ) throws IOException {
203
216
204
217
SecurityContextHolder .clearContext ();
205
218
206
- // TODO
207
- // The authorization server MAY return an HTTP 401 (Unauthorized) status code
208
- // to indicate which HTTP authentication schemes are supported.
209
- // If the client attempted to authenticate via the "Authorization" request header
210
- // field,
211
- // the authorization server MUST respond with an HTTP 401 (Unauthorized) status
212
- // code and
213
- // include the "WWW-Authenticate" response header field
214
- // matching the authentication scheme used by the client.
215
-
216
- OAuth2Error error = ((OAuth2AuthenticationException ) exception ).getError ();
219
+ if ( authenticationException instanceof OAuth2ClientAuthenticationException clientAuthenticationException ) {
220
+ OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationException
221
+ . getClientAuthentication ();
222
+ if ( ClientAuthenticationMethod . CLIENT_SECRET_BASIC
223
+ . equals ( clientAuthentication . getClientAuthenticationMethod ())) {
224
+ this . basicAuthenticationEntryPoint . commence ( request , response , authenticationException );
225
+ return ;
226
+ }
227
+ }
228
+
229
+ OAuth2Error error = ((OAuth2AuthenticationException ) authenticationException ).getError ();
217
230
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse (response );
218
231
if (OAuth2ErrorCodes .INVALID_CLIENT .equals (error .getErrorCode ())) {
219
232
httpResponse .setStatusCode (HttpStatus .UNAUTHORIZED );
@@ -248,4 +261,21 @@ private static void validateClientIdentifier(Authentication authentication) {
248
261
}
249
262
}
250
263
264
+ private static final class OAuth2ClientAuthenticationException extends OAuth2AuthenticationException {
265
+
266
+ private final OAuth2ClientAuthenticationToken clientAuthentication ;
267
+
268
+ private OAuth2ClientAuthenticationException (OAuth2Error error , Throwable cause ,
269
+ OAuth2ClientAuthenticationToken clientAuthentication ) {
270
+ super (error , cause );
271
+ Assert .notNull (clientAuthentication , "clientAuthentication cannot be null" );
272
+ this .clientAuthentication = clientAuthentication ;
273
+ }
274
+
275
+ private OAuth2ClientAuthenticationToken getClientAuthentication () {
276
+ return this .clientAuthentication ;
277
+ }
278
+
279
+ }
280
+
251
281
}
0 commit comments