From b317def78d8e316527b4edabf58ce8556c1edfc1 Mon Sep 17 00:00:00 2001 From: Vincent Potucek Date: Fri, 13 Jun 2025 23:07:29 +0200 Subject: [PATCH 1/2] use try-with-resources statement in ResourceRegionHttpMessageConverter Signed-off-by: Vincent Potucek --- .../ResourceRegionHttpMessageConverter.java | 34 ++-- ...sourceRegionHttpMessageConverterTests.java | 173 ++++++++++++++++++ .../messaging/StompSubProtocolHandler.java | 2 - .../StompSubProtocolHandlerTests.java | 122 ++++++++++++ 4 files changed, 308 insertions(+), 23 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java index 216fdc29361b..70ff64a165da 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java @@ -38,6 +38,9 @@ import org.springframework.util.MimeTypeUtils; import org.springframework.util.StreamUtils; +import static java.lang.Math.min; +import static org.springframework.util.StreamUtils.copyRange; + /** * Implementation of {@link HttpMessageConverter} that can write a single * {@link ResourceRegion} or Collections of {@link ResourceRegion ResourceRegions}. @@ -163,29 +166,18 @@ private boolean supportsRepeatableWrites(ResourceRegion region) { protected void writeResourceRegion(ResourceRegion region, HttpOutputMessage outputMessage) throws IOException { Assert.notNull(region, "ResourceRegion must not be null"); - HttpHeaders responseHeaders = outputMessage.getHeaders(); - long start = region.getPosition(); - long end = start + region.getCount() - 1; - long resourceLength = region.getResource().contentLength(); - end = Math.min(end, resourceLength - 1); - long rangeLength = end - start + 1; + var start = region.getPosition(); + var resourceLength = region.getResource().contentLength(); + var end = min(start + region.getCount() - 1, resourceLength - 1); + var responseHeaders = outputMessage.getHeaders(); + responseHeaders.setContentLength(end - start + 1); responseHeaders.add("Content-Range", "bytes " + start + '-' + end + '/' + resourceLength); - responseHeaders.setContentLength(rangeLength); - InputStream in = region.getResource().getInputStream(); - // We cannot use try-with-resources here for the InputStream, since we have - // custom handling of the close() method in a finally-block. - try { - StreamUtils.copyRange(in, outputMessage.getBody(), start, end); + try (var in = region.getResource().getInputStream()) { + copyRange(in, outputMessage.getBody(), start, end); } - finally { - try { - in.close(); - } - catch (IOException ex) { - // ignore - } + catch (IOException ignored) { } } @@ -227,14 +219,14 @@ private void writeResourceRegionCollection(Collection resourceRe println(out); } long resourceLength = region.getResource().contentLength(); - end = Math.min(end, resourceLength - inputStreamPosition - 1); + end = min(end, resourceLength - inputStreamPosition - 1); print(out, "Content-Range: bytes " + region.getPosition() + '-' + (region.getPosition() + region.getCount() - 1) + '/' + resourceLength); println(out); println(out); // Printing content - StreamUtils.copyRange(in, out, start, end); + copyRange(in, out, start, end); inputStreamPosition += (end + 1); } } diff --git a/spring-web/src/test/java/org/springframework/http/converter/ResourceRegionHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/ResourceRegionHttpMessageConverterTests.java index e538f9013cf0..f5eb2be69b18 100644 --- a/spring-web/src/test/java/org/springframework/http/converter/ResourceRegionHttpMessageConverterTests.java +++ b/spring-web/src/test/java/org/springframework/http/converter/ResourceRegionHttpMessageConverterTests.java @@ -17,6 +17,7 @@ package org.springframework.http.converter; import java.io.ByteArrayInputStream; +import java.io.IOException; import java.lang.reflect.Type; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -27,6 +28,7 @@ import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.InputStreamResource; import org.springframework.core.io.Resource; import org.springframework.core.io.support.ResourceRegion; import org.springframework.http.HttpHeaders; @@ -36,6 +38,7 @@ import org.springframework.web.testfixture.http.MockHttpOutputMessage; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -198,4 +201,174 @@ public void applicationOctetStreamDefaultContentType() throws Exception { assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8)).isEqualTo("Spring"); } + @Test + void shouldNotWriteForUnsupportedType() { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Object unsupportedBody = new Object(); + + assertThatThrownBy(() -> converter.write(unsupportedBody, null, outputMessage)) + .isInstanceOfAny(ClassCastException.class, HttpMessageNotWritableException.class); + } + + @Test + void shouldGetDefaultContentTypeForResourceRegion() { + Resource resource = new ClassPathResource("byterangeresource.txt", getClass()); + ResourceRegion region = new ResourceRegion(resource, 0, 10); + + MediaType contentType = converter.getDefaultContentType(region); + assertThat(contentType).isEqualTo(MediaType.TEXT_PLAIN); + } + + @Test + void shouldGetDefaultOctetStreamContentTypeForUnknownResource() { + Resource resource = mock(Resource.class); + given(resource.getFilename()).willReturn("unknown.dat"); + ResourceRegion region = new ResourceRegion(resource, 0, 10); + + MediaType contentType = converter.getDefaultContentType(region); + assertThat(contentType).isEqualTo(MediaType.APPLICATION_OCTET_STREAM); + } + + @Test + void shouldSupportRepeatableWritesForNonInputStreamResource() { + Resource resource = new ClassPathResource("byterangeresource.txt", getClass()); + ResourceRegion region = new ResourceRegion(resource, 0, 10); + + assertThat(converter.supportsRepeatableWrites(region)).isTrue(); + } + + @Test + void shouldNotSupportRepeatableWritesForInputStreamResource() { + Resource resource = mock(InputStreamResource.class); + ResourceRegion region = new ResourceRegion(resource, 0, 10); + + assertThat(converter.supportsRepeatableWrites(region)).isFalse(); + } + + @Test + void shouldHandleIOExceptionWhenWritingRegion() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource resource = mock(Resource.class); + given(resource.contentLength()).willReturn(10L); + given(resource.getInputStream()).willThrow(new IOException("Simulated error")); + ResourceRegion region = new ResourceRegion(resource, 0, 5); + + // Should not throw exception + converter.write(region, MediaType.TEXT_PLAIN, outputMessage); + + // Verify Content-Range header is set correctly + assertThat(outputMessage.getHeaders().getFirst(HttpHeaders.CONTENT_RANGE)) + .isEqualTo("bytes 0-4/10"); + + // Verify no content was written due to the IOException + assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8)).isEmpty(); + } + @Test + void shouldHandleIOExceptionWhenWritingRegionCollection() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource resource = mock(Resource.class); + given(resource.contentLength()).willReturn(10L); + given(resource.getInputStream()).willThrow(new IOException("Simulated error")); + ResourceRegion region = new ResourceRegion(resource, 0, 5); + List regions = Collections.singletonList(region); + + // Should not throw exception + converter.write(regions, MediaType.TEXT_PLAIN, outputMessage); + + assertThat(outputMessage.getHeaders().getContentType().toString()) + .isEqualTo("text/plain"); + } + + @Test + void shouldHandleNullResourceRegion() { + assertThatThrownBy(() -> converter.write(null, null, new MockHttpOutputMessage())) + .isInstanceOf(NullPointerException.class); + } + + @Test + void shouldHandleInvalidRangeBeyondResourceLength() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource resource = new ClassPathResource("byterangeresource.txt", getClass()); + ResourceRegion region = new ResourceRegion(resource, 35, 10); // Goes beyond resource length + + converter.write(region, MediaType.TEXT_PLAIN, outputMessage); + + assertThat(outputMessage.getHeaders().getFirst(HttpHeaders.CONTENT_RANGE)) + .isEqualTo("bytes 35-38/39"); + assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8)).hasSize(4); + } + + @Test + void shouldHandleZeroLengthResourceRegion() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource resource = new ClassPathResource("byterangeresource.txt", getClass()); + ResourceRegion region = new ResourceRegion(resource, 5, 0); + + converter.write(region, MediaType.TEXT_PLAIN, outputMessage); + + assertThat(outputMessage.getHeaders().getFirst(HttpHeaders.CONTENT_RANGE)) + .isEqualTo("bytes 5-4/39"); + assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8)).isEmpty(); + } + + @Test + void shouldHandleMultipleResourcesInCollection() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource resource1 = new ClassPathResource("byterangeresource.txt", getClass()); + Resource resource2 = new ClassPathResource("byterangeresource.txt", getClass()); + List regions = List.of( + new ResourceRegion(resource1, 0, 5), // "Spring" is 6 bytes (0-5) + new ResourceRegion(resource2, 7, 8) // "Framework" is 8 bytes (7-14) + ); + + converter.write(regions, MediaType.TEXT_PLAIN, outputMessage); + + String content = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + + // Verify multipart structure + assertThat(content).contains("Content-Type: text/plain"); + assertThat(content).contains("Content-Range: bytes 7-14/39"); + + // Verify partial content (note the ranges only include parts of the words) + assertThat(content).contains("Sprin"); // First 5 bytes of "Spring" (0-4) + assertThat(content).contains("Framewor"); // First 7 bytes of "Framework" (7-13) + } + @Test + void shouldHandleNullContentType() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource resource = new ClassPathResource("byterangeresource.txt", getClass()); + ResourceRegion region = new ResourceRegion(resource, 0, 5); + + converter.write(region, null, outputMessage); + + assertThat(outputMessage.getHeaders().getContentType()).isEqualTo(MediaType.TEXT_PLAIN); + } + + @Test + void shouldHandleUnreadableResource() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource resource = mock(Resource.class); + given(resource.contentLength()).willReturn(10L); + given(resource.getInputStream()).willThrow(new IOException("Cannot read resource")); + ResourceRegion region = new ResourceRegion(resource, 0, 5); + + converter.write(region, MediaType.TEXT_PLAIN, outputMessage); + + assertThat(outputMessage.getHeaders().getFirst(HttpHeaders.CONTENT_RANGE)) + .isEqualTo("bytes 0-4/10"); + assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8)).isEmpty(); + } + + @Test + void shouldHandleCanWriteWithNullType() { + assertThat(converter.canWrite(null, null, null)).isFalse(); + } + + @Test + void shouldHandleCanWriteWithNonParameterizedType() { + assertThat(converter.canWrite(ResourceRegion.class, null, null)).isTrue(); + assertThat(converter.canWrite(String.class, null, null)).isFalse(); + } + + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index e221b66fe728..7d054103160c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -399,8 +399,6 @@ private void sendErrorMessage(WebSocketSession session, Throwable error) { headerAccessor.setMessage(error.getMessage()); byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD); - // We cannot use try-with-resources here for the WebSocketSession, since we have - // custom handling of the close() method in a finally-block. try { session.sendMessage(new TextMessage(bytes)); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 83de8d40865c..bddd29ce25b1 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -531,7 +531,129 @@ public boolean send(Message message, long timeout) { verify(runnable, times(1)).run(); } + @Test + void handleMessageFromClientWithBinaryStompFrame() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + byte[] payload = new StompEncoder().encode(headers.getMessageHeaders(), EMPTY_PAYLOAD); + BinaryMessage binaryMessage = new BinaryMessage(payload); + + this.protocolHandler.afterSessionStarted(this.session, this.channel); + this.protocolHandler.handleMessageFromClient(this.session, binaryMessage, this.channel); + + verify(this.channel).send(this.messageCaptor.capture()); + Message actual = this.messageCaptor.getValue(); + assertThat(actual).isNotNull(); + assertThat(StompHeaderAccessor.wrap(actual).getCommand()).isEqualTo(StompCommand.CONNECT); + } + + @Test + void handleMessageFromClientWithPartialStompFrame() { + TextMessage partialMessage = new TextMessage("CONNECT\naccept-version:1.2\n\n"); + + this.protocolHandler.afterSessionStarted(this.session, this.channel); + this.protocolHandler.handleMessageFromClient(this.session, partialMessage, this.channel); + + verifyNoInteractions(this.channel); + assertThat(this.session.getSentMessages()).isEmpty(); + } + + @Test + void handleMessageFromClientWithInvalidStompFrame() { + TextMessage invalidMessage = new TextMessage("INVALID_COMMAND\n\n\0"); + + this.protocolHandler.afterSessionStarted(this.session, this.channel); + this.protocolHandler.handleMessageFromClient(this.session, invalidMessage, this.channel); + + verifyNoInteractions(this.channel); + assertThat(this.session.getSentMessages()).hasSize(1); + TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); + assertThat(actual.getPayload()).startsWith("ERROR"); + } + + @Test + void handleMessageToClientWithEmptyPayload() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE); + headers.setDestination("/topic/foo"); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); + + this.protocolHandler.handleMessageToClient(this.session, message); + + assertThat(this.session.getSentMessages()).hasSize(1); + WebSocketMessage textMessage = this.session.getSentMessages().get(0); + assertThat(textMessage.getPayload()).isEqualTo("MESSAGE\ndestination:/topic/foo\ncontent-length:0\n\n\u0000"); + } + + @Test + void handleMessageToClientWithLargePayload() { + byte[] largePayload = new byte[1024 * 64]; // 64KB + Arrays.fill(largePayload, (byte) 'A'); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE); + headers.setDestination("/topic/foo"); + headers.setContentType(MimeTypeUtils.APPLICATION_OCTET_STREAM); + Message message = MessageBuilder.createMessage(largePayload, headers.getMessageHeaders()); + + this.protocolHandler.handleMessageToClient(this.session, message); + + assertThat(this.session.getSentMessages()).hasSize(1); + WebSocketMessage webSocketMessage = this.session.getSentMessages().get(0); + assertThat(webSocketMessage).isInstanceOf(BinaryMessage.class); + } + + @Test + void handleMessageToClientWithErrorFrame() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); + headers.setMessage("Test error"); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); + + this.protocolHandler.handleMessageToClient(this.session, message); + + assertThat(this.session.getSentMessages()).hasSize(1); + TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); + assertThat(actual.getPayload()).startsWith("ERROR\nmessage:Test error"); + + // Verify session was closed + assertThat(this.session.isOpen()).isFalse(); + assertThat(this.session.getCloseStatus()).isEqualTo(CloseStatus.PROTOCOL_ERROR); + } + + @Test + void handleMessageToClientWithHeartbeat() { + StompHeaderAccessor headers = StompHeaderAccessor.createForHeartbeat(); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); + this.protocolHandler.handleMessageToClient(this.session, message); + + assertThat(this.session.getSentMessages()).hasSize(1); + TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); + assertThat(actual.getPayload()).isEqualTo("\n"); + }@Test + void sessionAttributesArePreserved() { + this.session.getAttributes().put("key", "value"); + + TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).build(); + this.protocolHandler.afterSessionStarted(this.session, this.channel); + this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); + + verify(this.channel).send(this.messageCaptor.capture()); + Message actual = this.messageCaptor.getValue(); + assertThat(SimpMessageHeaderAccessor.getSessionAttributes(actual.getHeaders())) + .containsEntry("key", "value"); + } + + @Test + void immutableMessageHandling() { + // Create immutable message + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.setImmutable(); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); + + this.protocolHandler.handleMessageToClient(this.session, message); + + assertThat(this.session.getSentMessages()).hasSize(1); + TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); + assertThat(actual.getPayload()).contains("SEND"); + } private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider { private UniqueUser(String name) { From 547a0dca9f1a9f0f15d656c7942b853f4cc4567e Mon Sep 17 00:00:00 2001 From: Vincent Potucek Date: Sat, 14 Jun 2025 11:05:19 +0200 Subject: [PATCH 2/2] remove false comment about try-with-resources statement in StompSubProtocolHandler Signed-off-by: Vincent Potucek --- .../ResourceRegionHttpMessageConverter.java | 34 +++-- .../StompSubProtocolHandlerTests.java | 122 ------------------ 2 files changed, 21 insertions(+), 135 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java index 70ff64a165da..216fdc29361b 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java @@ -38,9 +38,6 @@ import org.springframework.util.MimeTypeUtils; import org.springframework.util.StreamUtils; -import static java.lang.Math.min; -import static org.springframework.util.StreamUtils.copyRange; - /** * Implementation of {@link HttpMessageConverter} that can write a single * {@link ResourceRegion} or Collections of {@link ResourceRegion ResourceRegions}. @@ -166,18 +163,29 @@ private boolean supportsRepeatableWrites(ResourceRegion region) { protected void writeResourceRegion(ResourceRegion region, HttpOutputMessage outputMessage) throws IOException { Assert.notNull(region, "ResourceRegion must not be null"); + HttpHeaders responseHeaders = outputMessage.getHeaders(); - var start = region.getPosition(); - var resourceLength = region.getResource().contentLength(); - var end = min(start + region.getCount() - 1, resourceLength - 1); - var responseHeaders = outputMessage.getHeaders(); - responseHeaders.setContentLength(end - start + 1); + long start = region.getPosition(); + long end = start + region.getCount() - 1; + long resourceLength = region.getResource().contentLength(); + end = Math.min(end, resourceLength - 1); + long rangeLength = end - start + 1; responseHeaders.add("Content-Range", "bytes " + start + '-' + end + '/' + resourceLength); + responseHeaders.setContentLength(rangeLength); - try (var in = region.getResource().getInputStream()) { - copyRange(in, outputMessage.getBody(), start, end); + InputStream in = region.getResource().getInputStream(); + // We cannot use try-with-resources here for the InputStream, since we have + // custom handling of the close() method in a finally-block. + try { + StreamUtils.copyRange(in, outputMessage.getBody(), start, end); } - catch (IOException ignored) { + finally { + try { + in.close(); + } + catch (IOException ex) { + // ignore + } } } @@ -219,14 +227,14 @@ private void writeResourceRegionCollection(Collection resourceRe println(out); } long resourceLength = region.getResource().contentLength(); - end = min(end, resourceLength - inputStreamPosition - 1); + end = Math.min(end, resourceLength - inputStreamPosition - 1); print(out, "Content-Range: bytes " + region.getPosition() + '-' + (region.getPosition() + region.getCount() - 1) + '/' + resourceLength); println(out); println(out); // Printing content - copyRange(in, out, start, end); + StreamUtils.copyRange(in, out, start, end); inputStreamPosition += (end + 1); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index bddd29ce25b1..83de8d40865c 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -531,129 +531,7 @@ public boolean send(Message message, long timeout) { verify(runnable, times(1)).run(); } - @Test - void handleMessageFromClientWithBinaryStompFrame() { - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - byte[] payload = new StompEncoder().encode(headers.getMessageHeaders(), EMPTY_PAYLOAD); - BinaryMessage binaryMessage = new BinaryMessage(payload); - - this.protocolHandler.afterSessionStarted(this.session, this.channel); - this.protocolHandler.handleMessageFromClient(this.session, binaryMessage, this.channel); - - verify(this.channel).send(this.messageCaptor.capture()); - Message actual = this.messageCaptor.getValue(); - assertThat(actual).isNotNull(); - assertThat(StompHeaderAccessor.wrap(actual).getCommand()).isEqualTo(StompCommand.CONNECT); - } - - @Test - void handleMessageFromClientWithPartialStompFrame() { - TextMessage partialMessage = new TextMessage("CONNECT\naccept-version:1.2\n\n"); - - this.protocolHandler.afterSessionStarted(this.session, this.channel); - this.protocolHandler.handleMessageFromClient(this.session, partialMessage, this.channel); - - verifyNoInteractions(this.channel); - assertThat(this.session.getSentMessages()).isEmpty(); - } - - @Test - void handleMessageFromClientWithInvalidStompFrame() { - TextMessage invalidMessage = new TextMessage("INVALID_COMMAND\n\n\0"); - - this.protocolHandler.afterSessionStarted(this.session, this.channel); - this.protocolHandler.handleMessageFromClient(this.session, invalidMessage, this.channel); - - verifyNoInteractions(this.channel); - assertThat(this.session.getSentMessages()).hasSize(1); - TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); - assertThat(actual.getPayload()).startsWith("ERROR"); - } - - @Test - void handleMessageToClientWithEmptyPayload() { - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE); - headers.setDestination("/topic/foo"); - Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); - - this.protocolHandler.handleMessageToClient(this.session, message); - - assertThat(this.session.getSentMessages()).hasSize(1); - WebSocketMessage textMessage = this.session.getSentMessages().get(0); - assertThat(textMessage.getPayload()).isEqualTo("MESSAGE\ndestination:/topic/foo\ncontent-length:0\n\n\u0000"); - } - - @Test - void handleMessageToClientWithLargePayload() { - byte[] largePayload = new byte[1024 * 64]; // 64KB - Arrays.fill(largePayload, (byte) 'A'); - - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE); - headers.setDestination("/topic/foo"); - headers.setContentType(MimeTypeUtils.APPLICATION_OCTET_STREAM); - Message message = MessageBuilder.createMessage(largePayload, headers.getMessageHeaders()); - - this.protocolHandler.handleMessageToClient(this.session, message); - - assertThat(this.session.getSentMessages()).hasSize(1); - WebSocketMessage webSocketMessage = this.session.getSentMessages().get(0); - assertThat(webSocketMessage).isInstanceOf(BinaryMessage.class); - } - - @Test - void handleMessageToClientWithErrorFrame() { - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); - headers.setMessage("Test error"); - Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); - - this.protocolHandler.handleMessageToClient(this.session, message); - - assertThat(this.session.getSentMessages()).hasSize(1); - TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); - assertThat(actual.getPayload()).startsWith("ERROR\nmessage:Test error"); - - // Verify session was closed - assertThat(this.session.isOpen()).isFalse(); - assertThat(this.session.getCloseStatus()).isEqualTo(CloseStatus.PROTOCOL_ERROR); - } - - @Test - void handleMessageToClientWithHeartbeat() { - StompHeaderAccessor headers = StompHeaderAccessor.createForHeartbeat(); - Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); - this.protocolHandler.handleMessageToClient(this.session, message); - - assertThat(this.session.getSentMessages()).hasSize(1); - TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); - assertThat(actual.getPayload()).isEqualTo("\n"); - }@Test - void sessionAttributesArePreserved() { - this.session.getAttributes().put("key", "value"); - - TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).build(); - this.protocolHandler.afterSessionStarted(this.session, this.channel); - this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); - - verify(this.channel).send(this.messageCaptor.capture()); - Message actual = this.messageCaptor.getValue(); - assertThat(SimpMessageHeaderAccessor.getSessionAttributes(actual.getHeaders())) - .containsEntry("key", "value"); - } - - @Test - void immutableMessageHandling() { - // Create immutable message - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); - headers.setImmutable(); - Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); - - this.protocolHandler.handleMessageToClient(this.session, message); - - assertThat(this.session.getSentMessages()).hasSize(1); - TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); - assertThat(actual.getPayload()).contains("SEND"); - } private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider { private UniqueUser(String name) {