diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index 29186d0a19a3..d3d007bea96c 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -477,7 +477,14 @@ public RequestBodySpec body(T body, ParameterizedTypeReference bodyType) @Override public RequestBodySpec body(StreamingHttpOutputMessage.Body body) { - this.body = request -> body.writeTo(request.getBody()); + this.body = request -> { + if (request instanceof StreamingHttpOutputMessage streamingMessage) { + streamingMessage.setBody(body); + } + else { + body.writeTo(request.getBody()); + } + }; return this; } diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java index f3b05812da7f..e82cf02f6fda 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java @@ -16,6 +16,7 @@ package org.springframework.web.client; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -559,6 +560,27 @@ void postUserAsJsonWithJsonView(ClientHttpRequestFactory requestFactory) { }); } + @ParameterizedRestClientTest + void postStreamingMessageBody(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setResponseCode(200)); + + ResponseEntity result = this.restClient.post() + .uri("/streaming/body") + .body(new ByteArrayInputStream("test-data".getBytes(UTF_8))::transferTo) + .retrieve() + .toBodilessEntity(); + + assertThat(result.getStatusCode()).isEqualTo(HttpStatus.OK); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/streaming/body"); + assertThat(request.getBody().readUtf8()).isEqualTo("test-data"); + }); + } + @ParameterizedRestClientTest // gh-31361 public void postForm(ClientHttpRequestFactory requestFactory) { startServer(requestFactory);