From 5aa5639efa1fcb94b2ff2488da8a90b7c81ebe6e Mon Sep 17 00:00:00 2001
From: Peter-Josef Meisch <pj.meisch@sothawo.com>
Date: Sun, 20 Feb 2022 12:24:04 +0100
Subject: [PATCH] Remove blocking code in SearchDocument processing.

Closes #2025
---
 .../core/ElasticsearchRestTemplate.java       | 19 +++---
 .../core/ReactiveElasticsearchTemplate.java   | 13 ++--
 .../core/document/SearchDocumentResponse.java | 38 ++++++++++--
 ...ggestReactiveTemplateIntegrationTests.java | 60 ++++++++++---------
 4 files changed, 83 insertions(+), 47 deletions(-)

diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java
index 7bc4c2aba..8db157100 100644
--- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java
+++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java
@@ -21,6 +21,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -386,7 +387,7 @@ public <T> SearchHits<T> search(Query query, Class<T> clazz, IndexCoordinates in
 		ReadDocumentCallback<T> documentCallback = new ReadDocumentCallback<>(elasticsearchConverter, clazz, index);
 		SearchDocumentResponseCallback<SearchHits<T>> callback = new ReadSearchDocumentResponseCallback<>(clazz, index);
 
-		return callback.doWith(SearchDocumentResponse.from(response, documentCallback::doWith));
+		return callback.doWith(SearchDocumentResponse.from(response, getEntityCreator(documentCallback)));
 	}
 
 	protected <T> SearchHits<T> doSearch(MoreLikeThisQuery query, Class<T> clazz, IndexCoordinates index) {
@@ -410,7 +411,7 @@ public <T> SearchScrollHits<T> searchScrollStart(long scrollTimeInMillis, Query
 		ReadDocumentCallback<T> documentCallback = new ReadDocumentCallback<>(elasticsearchConverter, clazz, index);
 		SearchDocumentResponseCallback<SearchScrollHits<T>> callback = new ReadSearchScrollDocumentResponseCallback<>(clazz,
 				index);
-		return callback.doWith(SearchDocumentResponse.from(response, documentCallback::doWith));
+		return callback.doWith(SearchDocumentResponse.from(response, getEntityCreator(documentCallback)));
 	}
 
 	@Override
@@ -425,7 +426,7 @@ public <T> SearchScrollHits<T> searchScrollContinue(@Nullable String scrollId, l
 		ReadDocumentCallback<T> documentCallback = new ReadDocumentCallback<>(elasticsearchConverter, clazz, index);
 		SearchDocumentResponseCallback<SearchScrollHits<T>> callback = new ReadSearchScrollDocumentResponseCallback<>(clazz,
 				index);
-		return callback.doWith(SearchDocumentResponse.from(response, documentCallback::doWith));
+		return callback.doWith(SearchDocumentResponse.from(response, getEntityCreator(documentCallback)));
 	}
 
 	@Override
@@ -458,7 +459,7 @@ public <T> List<SearchHits<T>> multiSearch(List<? extends Query> queries, Class<
 		SearchDocumentResponseCallback<SearchHits<T>> callback = new ReadSearchDocumentResponseCallback<>(clazz, index);
 		List<SearchHits<T>> res = new ArrayList<>(queries.size());
 		for (int i = 0; i < queries.size(); i++) {
-			res.add(callback.doWith(SearchDocumentResponse.from(items[i].getResponse(), documentCallback::doWith)));
+			res.add(callback.doWith(SearchDocumentResponse.from(items[i].getResponse(), getEntityCreator(documentCallback))));
 		}
 		return res;
 	}
@@ -491,7 +492,7 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
 					index);
 
 			SearchResponse response = items[i].getResponse();
-			res.add(callback.doWith(SearchDocumentResponse.from(response, documentCallback::doWith)));
+			res.add(callback.doWith(SearchDocumentResponse.from(response, getEntityCreator(documentCallback))));
 		}
 		return res;
 	}
@@ -524,7 +525,7 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
 					index);
 
 			SearchResponse response = items[i].getResponse();
-			res.add(callback.doWith(SearchDocumentResponse.from(response, documentCallback::doWith)));
+			res.add(callback.doWith(SearchDocumentResponse.from(response, getEntityCreator(documentCallback))));
 		}
 		return res;
 	}
@@ -535,8 +536,12 @@ protected MultiSearchResponse.Item[] getMultiSearchResult(MultiSearchRequest req
 		Assert.isTrue(items.length == request.requests().size(), "Response should has same length with queries");
 		return items;
 	}
-	// endregion
 
+	private <T> SearchDocumentResponse.EntityCreator<T> getEntityCreator(ReadDocumentCallback<T> documentCallback) {
+		return searchDocument -> CompletableFuture.completedFuture(documentCallback.doWith(searchDocument));
+	}
+
+	// endregion
 	// region ClientCallback
 	/**
 	 * Callback interface to be used with {@link #execute(ClientCallback)} for operating directly on
diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java
index 47d703f1a..016f1fbf9 100644
--- a/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java
+++ b/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java
@@ -23,7 +23,6 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import org.apache.commons.logging.Log;
@@ -820,15 +819,17 @@ private Flux<SearchDocument> doFind(Query query, Class<?> clazz, IndexCoordinate
 		});
 	}
 
-	private Mono<SearchDocumentResponse> doFindForResponse(Query query, Class<?> clazz, IndexCoordinates index) {
+	private <T> Mono<SearchDocumentResponse> doFindForResponse(Query query, Class<?> clazz, IndexCoordinates index) {
 
 		return Mono.defer(() -> {
 			SearchRequest request = requestFactory.searchRequest(query, clazz, index);
 			request = prepareSearchRequest(request, false);
 
 			SearchDocumentCallback<?> documentCallback = new ReadSearchDocumentCallback<>(clazz, index);
-			Function<SearchDocument, Object> entityCreator = searchDocument -> documentCallback.toEntity(searchDocument)
-					.block();
+			// noinspection unchecked
+			SearchDocumentResponse.EntityCreator<T> entityCreator = searchDocument -> ((Mono<T>) documentCallback
+					.toEntity(searchDocument)).toFuture();
+
 			return doFindForResponse(request, entityCreator);
 		});
 	}
@@ -949,8 +950,8 @@ protected Flux<SearchDocument> doFind(SearchRequest request) {
 	 * @param entityCreator
 	 * @return a {@link Mono} emitting the result of the operation converted to s {@link SearchDocumentResponse}.
 	 */
-	protected Mono<SearchDocumentResponse> doFindForResponse(SearchRequest request,
-			Function<SearchDocument, ? extends Object> entityCreator) {
+	protected <T> Mono<SearchDocumentResponse> doFindForResponse(SearchRequest request,
+			SearchDocumentResponse.EntityCreator<T> entityCreator) {
 
 		if (QUERY_LOGGER.isDebugEnabled()) {
 			QUERY_LOGGER.debug(String.format("Executing doFindForResponse: %s", request));
diff --git a/src/main/java/org/springframework/data/elasticsearch/core/document/SearchDocumentResponse.java b/src/main/java/org/springframework/data/elasticsearch/core/document/SearchDocumentResponse.java
index abced006b..b23b70459 100644
--- a/src/main/java/org/springframework/data/elasticsearch/core/document/SearchDocumentResponse.java
+++ b/src/main/java/org/springframework/data/elasticsearch/core/document/SearchDocumentResponse.java
@@ -17,8 +17,11 @@
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
 import java.util.function.Function;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.text.Text;
@@ -38,13 +41,15 @@
 
 /**
  * This represents the complete search response from Elasticsearch, including the returned documents. Instances must be
- * created with the {@link #from(SearchResponse,Function)} method.
+ * created with the {@link #from(SearchResponse, EntityCreator)} method.
  *
  * @author Peter-Josef Meisch
  * @since 4.0
  */
 public class SearchDocumentResponse {
 
+	private static final Log LOGGER = LogFactory.getLog(SearchDocumentResponse.class);
+
 	private final long totalHits;
 	private final String totalHitsRelation;
 	private final float maxScore;
@@ -102,8 +107,7 @@ public Suggest getSuggest() {
 	 * @param <T> entity type
 	 * @return the SearchDocumentResponse
 	 */
-	public static <T> SearchDocumentResponse from(SearchResponse searchResponse,
-			Function<SearchDocument, T> entityCreator) {
+	public static <T> SearchDocumentResponse from(SearchResponse searchResponse, EntityCreator<T> entityCreator) {
 
 		Assert.notNull(searchResponse, "searchResponse must not be null");
 
@@ -129,7 +133,7 @@ public static <T> SearchDocumentResponse from(SearchResponse searchResponse,
 	 */
 	public static <T> SearchDocumentResponse from(SearchHits searchHits, @Nullable String scrollId,
 			@Nullable Aggregations aggregations, @Nullable org.elasticsearch.search.suggest.Suggest suggestES,
-			Function<SearchDocument, T> entityCreator) {
+			EntityCreator<T> entityCreator) {
 
 		TotalHits responseTotalHits = searchHits.getTotalHits();
 
@@ -160,7 +164,7 @@ public static <T> SearchDocumentResponse from(SearchHits searchHits, @Nullable S
 
 	@Nullable
 	private static <T> Suggest suggestFrom(@Nullable org.elasticsearch.search.suggest.Suggest suggestES,
-			Function<SearchDocument, T> entityCreator) {
+			EntityCreator<T> entityCreator) {
 
 		if (suggestES == null) {
 			return null;
@@ -219,7 +223,19 @@ private static <T> Suggest suggestFrom(@Nullable org.elasticsearch.search.sugges
 					List<CompletionSuggestion.Entry.Option<T>> options = new ArrayList<>();
 					for (org.elasticsearch.search.suggest.completion.CompletionSuggestion.Entry.Option optionES : entryES) {
 						SearchDocument searchDocument = optionES.getHit() != null ? DocumentAdapters.from(optionES.getHit()) : null;
-						T hitEntity = searchDocument != null ? entityCreator.apply(searchDocument) : null;
+
+						T hitEntity = null;
+
+						if (searchDocument != null) {
+							try {
+								hitEntity = entityCreator.apply(searchDocument).get();
+							} catch (Exception e) {
+								if (LOGGER.isWarnEnabled()) {
+									LOGGER.warn("Error creating entity from SearchDocument");
+								}
+							}
+						}
+
 						options.add(new CompletionSuggestion.Entry.Option<T>(textToString(optionES.getText()),
 								textToString(optionES.getHighlighted()), optionES.getScore(), optionES.collateMatch(),
 								optionES.getContexts(), scoreDocFrom(optionES.getDoc()), searchDocument, hitEntity));
@@ -254,4 +270,14 @@ private static ScoreDoc scoreDocFrom(@Nullable org.apache.lucene.search.ScoreDoc
 	private static String textToString(@Nullable Text text) {
 		return text != null ? text.string() : "";
 	}
+
+	/**
+	 * A function to convert a {@link SearchDocument} async into an entity. Asynchronous so that it can be used from the
+	 * imperative and the reactive code.
+	 *
+	 * @param <T> the entity type
+	 */
+	@FunctionalInterface
+	public interface EntityCreator<T> extends Function<SearchDocument, CompletableFuture<T>> {}
+
 }
diff --git a/src/test/java/org/springframework/data/elasticsearch/core/suggest/SuggestReactiveTemplateIntegrationTests.java b/src/test/java/org/springframework/data/elasticsearch/core/suggest/SuggestReactiveTemplateIntegrationTests.java
index 52fba50b6..e693ac773 100644
--- a/src/test/java/org/springframework/data/elasticsearch/core/suggest/SuggestReactiveTemplateIntegrationTests.java
+++ b/src/test/java/org/springframework/data/elasticsearch/core/suggest/SuggestReactiveTemplateIntegrationTests.java
@@ -37,11 +37,11 @@
 import org.springframework.data.annotation.Id;
 import org.springframework.data.elasticsearch.annotations.CompletionField;
 import org.springframework.data.elasticsearch.annotations.Document;
-import org.springframework.data.elasticsearch.core.query.NativeSearchQuery;
-import org.springframework.data.elasticsearch.core.query.NativeSearchQueryBuilder;
 import org.springframework.data.elasticsearch.core.ReactiveElasticsearchOperations;
 import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates;
 import org.springframework.data.elasticsearch.core.query.IndexQuery;
+import org.springframework.data.elasticsearch.core.query.NativeSearchQueryBuilder;
+import org.springframework.data.elasticsearch.core.query.Query;
 import org.springframework.data.elasticsearch.core.suggest.response.CompletionSuggestion;
 import org.springframework.data.elasticsearch.core.suggest.response.Suggest;
 import org.springframework.data.elasticsearch.junit.jupiter.ReactiveElasticsearchRestTemplateConfiguration;
@@ -88,34 +88,38 @@ void shouldDoSomeTest() {
 	@DisplayName("should find suggestions for given prefix completion")
 	void shouldFindSuggestionsForGivenPrefixCompletion() {
 
-		loadCompletionObjectEntities().map(unused -> {
-
-			NativeSearchQuery query = new NativeSearchQueryBuilder().withSuggestBuilder(new SuggestBuilder()
-					.addSuggestion("test-suggest", SuggestBuilders.completionSuggestion("suggest").prefix("m", Fuzziness.AUTO)))
-					.build();
-
-			operations.suggest(query, CompletionEntity.class) //
-					.as(StepVerifier::create) //
-					.assertNext(suggest -> {
-						Suggest.Suggestion<? extends Suggest.Suggestion.Entry<? extends Suggest.Suggestion.Entry.Option>> suggestion = suggest
-								.getSuggestion("test-suggest");
-						assertThat(suggestion).isNotNull();
-						assertThat(suggestion).isInstanceOf(CompletionSuggestion.class);
-						// noinspection unchecked
-						List<CompletionSuggestion.Entry.Option<CompletionIntegrationTests.AnnotatedCompletionEntity>> options = ((CompletionSuggestion<CompletionIntegrationTests.AnnotatedCompletionEntity>) suggestion)
-								.getEntries().get(0).getOptions();
-						assertThat(options).hasSize(2);
-						assertThat(options.get(0).getText()).isIn("Marchand", "Mohsin");
-						assertThat(options.get(1).getText()).isIn("Marchand", "Mohsin");
-
-					}) //
-					.verifyComplete();
-			return Mono.empty();
-		});
+		loadCompletionObjectEntities() //
+				.flatMap(unused -> {
+					Query query = getSuggestQuery("test-suggest", "suggest", "m");
+					return operations.suggest(query, CompletionEntity.class);
+				}) //
+				.as(StepVerifier::create) //
+				.assertNext(suggest -> {
+					Suggest.Suggestion<? extends Suggest.Suggestion.Entry<? extends Suggest.Suggestion.Entry.Option>> suggestion = suggest
+							.getSuggestion("test-suggest");
+					assertThat(suggestion).isNotNull();
+					assertThat(suggestion).isInstanceOf(CompletionSuggestion.class);
+					// noinspection unchecked
+					List<CompletionSuggestion.Entry.Option<CompletionIntegrationTests.AnnotatedCompletionEntity>> options = ((CompletionSuggestion<CompletionIntegrationTests.AnnotatedCompletionEntity>) suggestion)
+							.getEntries().get(0).getOptions();
+					assertThat(options).hasSize(2);
+					assertThat(options.get(0).getText()).isIn("Marchand", "Mohsin");
+					assertThat(options.get(1).getText()).isIn("Marchand", "Mohsin");
+				}) //
+				.verifyComplete();
+	}
+
+	protected Query getSuggestQuery(String suggestionName, String fieldName, String prefix) {
+		return new NativeSearchQueryBuilder() //
+				.withSuggestBuilder(new SuggestBuilder() //
+						.addSuggestion(suggestionName, //
+								SuggestBuilders.completionSuggestion(fieldName) //
+										.prefix(prefix, Fuzziness.AUTO))) //
+				.build(); //
 	}
 
 	// region helper functions
-	private Mono<Void> loadCompletionObjectEntities() {
+	private Mono<CompletionEntity> loadCompletionObjectEntities() {
 
 		CompletionEntity rizwan_idrees = new CompletionEntityBuilder("1").name("Rizwan Idrees")
 				.suggest(new String[] { "Rizwan Idrees" }).build();
@@ -128,7 +132,7 @@ private Mono<Void> loadCompletionObjectEntities() {
 		List<CompletionEntity> entities = new ArrayList<>(
 				Arrays.asList(rizwan_idrees, franck_marchand, mohsin_husen, artur_konczak));
 		IndexCoordinates index = IndexCoordinates.of(indexNameProvider.indexName());
-		return operations.saveAll(entities, index).then();
+		return operations.saveAll(entities, index).last();
 	}
 	// endregion