Skip to content

Commit 0fd774e

Browse files
committed
Add allowedOriginPatterns to WebSocketHandlerRegistration
Closes gh-26593
1 parent ec5774e commit 0fd774e

File tree

3 files changed

+45
-7
lines changed

3 files changed

+45
-7
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -54,6 +54,8 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
5454

5555
private final List<String> allowedOrigins = new ArrayList<>();
5656

57+
private final List<String> allowedOriginPatterns = new ArrayList<>();
58+
5759
@Nullable
5860
private SockJsServiceRegistration sockJsServiceRegistration;
5961

@@ -94,6 +96,15 @@ public WebSocketHandlerRegistration setAllowedOrigins(String... allowedOrigins)
9496
return this;
9597
}
9698

99+
@Override
100+
public WebSocketHandlerRegistration setAllowedOriginPatterns(String... allowedOriginPatterns) {
101+
this.allowedOriginPatterns.clear();
102+
if (!ObjectUtils.isEmpty(allowedOriginPatterns)) {
103+
this.allowedOriginPatterns.addAll(Arrays.asList(allowedOriginPatterns));
104+
}
105+
return this;
106+
}
107+
97108
@Override
98109
public SockJsServiceRegistration withSockJS() {
99110
this.sockJsServiceRegistration = new SockJsServiceRegistration();
@@ -108,13 +119,21 @@ public SockJsServiceRegistration withSockJS() {
108119
if (!this.allowedOrigins.isEmpty()) {
109120
this.sockJsServiceRegistration.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins));
110121
}
122+
if (!this.allowedOriginPatterns.isEmpty()) {
123+
this.sockJsServiceRegistration.setAllowedOriginPatterns(
124+
StringUtils.toStringArray(this.allowedOriginPatterns));
125+
}
111126
return this.sockJsServiceRegistration;
112127
}
113128

114129
protected HandshakeInterceptor[] getInterceptors() {
115130
List<HandshakeInterceptor> interceptors = new ArrayList<>(this.interceptors.size() + 1);
116131
interceptors.addAll(this.interceptors);
117-
interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins));
132+
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(this.allowedOrigins);
133+
if (!ObjectUtils.isEmpty(this.allowedOriginPatterns)) {
134+
interceptor.setAllowedOriginPatterns(this.allowedOriginPatterns);
135+
}
136+
interceptors.add(interceptor);
118137
return interceptors.toArray(new HandshakeInterceptor[0]);
119138
}
120139

spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -63,6 +63,15 @@ public interface WebSocketHandlerRegistration {
6363
*/
6464
WebSocketHandlerRegistration setAllowedOrigins(String... origins);
6565

66+
/**
67+
* A variant of {@link #setAllowedOrigins(String...)} that accepts flexible
68+
* domain patterns, e.g. {@code "https://*.domain1.com"}. Furthermore it
69+
* always sets the {@code Access-Control-Allow-Origin} response header to
70+
* the matched origin and never to {@code "*"}, nor to any other pattern.
71+
* @since 5.3.5
72+
*/
73+
WebSocketHandlerRegistration setAllowedOriginPatterns(String... originPatterns);
74+
6675
/**
6776
* Enable SockJS fallback options.
6877
*/

spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -115,7 +115,10 @@ public void interceptorsWithAllowedOrigins() {
115115
WebSocketHandler handler = new TextWebSocketHandler();
116116
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
117117

118-
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins("https://mydomain1.example");
118+
this.registration.addHandler(handler, "/foo")
119+
.addInterceptors(interceptor)
120+
.setAllowedOrigins("https://mydomain1.example")
121+
.setAllowedOriginPatterns("https://*.abc.com");
119122

120123
List<Mapping> mappings = this.registration.getMappings();
121124
assertThat(mappings.size()).isEqualTo(1);
@@ -126,7 +129,10 @@ public void interceptorsWithAllowedOrigins() {
126129
assertThat(mapping.interceptors).isNotNull();
127130
assertThat(mapping.interceptors.length).isEqualTo(2);
128131
assertThat(mapping.interceptors[0]).isEqualTo(interceptor);
129-
assertThat(mapping.interceptors[1].getClass()).isEqualTo(OriginHandshakeInterceptor.class);
132+
133+
OriginHandshakeInterceptor originInterceptor = (OriginHandshakeInterceptor) mapping.interceptors[1];
134+
assertThat(originInterceptor.getAllowedOrigins()).containsExactly("https://mydomain1.example");
135+
assertThat(originInterceptor.getAllowedOriginPatterns()).containsExactly("https://*.abc.com");
130136
}
131137

132138
@Test
@@ -137,6 +143,7 @@ public void interceptorsPassedToSockJsRegistration() {
137143
this.registration.addHandler(handler, "/foo")
138144
.addInterceptors(interceptor)
139145
.setAllowedOrigins("https://mydomain1.example")
146+
.setAllowedOriginPatterns("https://*.abc.com")
140147
.withSockJS();
141148

142149
this.registration.getSockJsServiceRegistration().setTaskScheduler(this.taskScheduler);
@@ -151,7 +158,10 @@ public void interceptorsPassedToSockJsRegistration() {
151158
assertThat(mapping.sockJsService.getAllowedOrigins().contains("https://mydomain1.example")).isTrue();
152159
List<HandshakeInterceptor> interceptors = mapping.sockJsService.getHandshakeInterceptors();
153160
assertThat(interceptors.get(0)).isEqualTo(interceptor);
154-
assertThat(interceptors.get(1).getClass()).isEqualTo(OriginHandshakeInterceptor.class);
161+
162+
OriginHandshakeInterceptor originInterceptor = (OriginHandshakeInterceptor) interceptors.get(1);
163+
assertThat(originInterceptor.getAllowedOrigins()).containsExactly("https://mydomain1.example");
164+
assertThat(originInterceptor.getAllowedOriginPatterns()).containsExactly("https://*.abc.com");
155165
}
156166

157167
@Test

0 commit comments

Comments
 (0)