Skip to content

Commit c78a65d

Browse files
authored
Fix: OpenAI: map exceptions for streaming and moderation models (langchain4j#2986)
## Issue `OpenAiStreamingChatModel`, `OpenAiStreamingLanguageModel` and `OpenAiModerationModel` are not mapping low-level `HttpException` to common exceptions (e.g., `AuthenticationException` and `RateLimitException`) like all other `OpenAi*Model` do. ## Change Map exceptions for `OpenAiStreamingChatModel`, `OpenAiStreamingLanguageModel` and `OpenAiModerationModel` ## General checklist - [ ] There are no breaking changes - [x] I have added unit and/or integration tests for my change - [x] The tests cover both positive and negative cases - [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
1 parent 7f3c69a commit c78a65d

File tree

8 files changed

+161
-22
lines changed

8 files changed

+161
-22
lines changed

langchain4j-core/src/main/java/dev/langchain4j/internal/ExceptionMapper.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ default <T> T withExceptionMapper(Callable<T> action) {
3232
}
3333
}
3434

35-
RuntimeException mapException(Exception e);
35+
RuntimeException mapException(Throwable t);
3636

3737
class DefaultExceptionMapper implements ExceptionMapper {
3838

3939
@Override
40-
public RuntimeException mapException(Exception e) {
41-
Throwable rootCause = findRoot(e);
40+
public RuntimeException mapException(Throwable t) {
41+
Throwable rootCause = findRoot(t);
4242

4343
if (rootCause instanceof HttpException httpException) {
4444
return mapHttpStatusCode(httpException, httpException.statusCode());
@@ -48,29 +48,29 @@ public RuntimeException mapException(Exception e) {
4848
return new UnresolvedModelServerException(rootCause);
4949
}
5050

51-
return e instanceof RuntimeException re ? re : new LangChain4jException(e);
51+
return t instanceof RuntimeException re ? re : new LangChain4jException(t);
5252
}
5353

54-
protected RuntimeException mapHttpStatusCode(Exception rootException, int httpStatusCode) {
54+
protected RuntimeException mapHttpStatusCode(Throwable cause, int httpStatusCode) {
5555
if (httpStatusCode >= 500 && httpStatusCode < 600) {
56-
return new InternalServerException(rootException);
56+
return new InternalServerException(cause);
5757
}
5858
if (httpStatusCode == 401 || httpStatusCode == 403) {
59-
return new AuthenticationException(rootException);
59+
return new AuthenticationException(cause);
6060
}
6161
if (httpStatusCode == 404) {
62-
return new ModelNotFoundException(rootException);
62+
return new ModelNotFoundException(cause);
6363
}
6464
if (httpStatusCode == 408) {
65-
return new TimeoutException(rootException);
65+
return new TimeoutException(cause);
6666
}
6767
if (httpStatusCode == 429) {
68-
return new RateLimitException(rootException);
68+
return new RateLimitException(cause);
6969
}
7070
if (httpStatusCode >= 400 && httpStatusCode < 500) {
71-
return new InvalidRequestException(rootException);
71+
return new InvalidRequestException(cause);
7272
}
73-
return rootException instanceof RuntimeException re ? re : new LangChain4jException(rootException);
73+
return cause instanceof RuntimeException re ? re : new LangChain4jException(cause);
7474
}
7575

7676
private static Throwable findRoot(Throwable e) {

langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaExceptionMapper.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ class JlamaExceptionMapper extends ExceptionMapper.DefaultExceptionMapper {
1414
private JlamaExceptionMapper() { }
1515

1616
@Override
17-
public RuntimeException mapException(final Exception e) {
18-
if (e instanceof IOException && e.getMessage().startsWith(JLAMA_IOEXCEPTION_START_MESSAGE)) {
19-
String httpStatusCode = e.getMessage().substring(JLAMA_IOEXCEPTION_START_MESSAGE.length(), JLAMA_IOEXCEPTION_START_MESSAGE.length() + 3);
17+
public RuntimeException mapException(Throwable t) {
18+
if (t instanceof IOException && t.getMessage().startsWith(JLAMA_IOEXCEPTION_START_MESSAGE)) {
19+
String httpStatusCode = t.getMessage().substring(JLAMA_IOEXCEPTION_START_MESSAGE.length(), JLAMA_IOEXCEPTION_START_MESSAGE.length() + 3);
2020
try {
21-
return mapHttpStatusCode(e, Integer.parseInt(httpStatusCode));
21+
return mapHttpStatusCode(t, Integer.parseInt(httpStatusCode));
2222
} catch (NumberFormatException nfe) {
2323
// ignore
2424
}
2525
}
2626

27-
return new InternalServerException(e);
27+
return new InternalServerException(t);
2828
}
2929
}

langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiModerationModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.Map;
2121

2222
import static dev.langchain4j.internal.RetryUtils.withRetry;
23+
import static dev.langchain4j.internal.RetryUtils.withRetryMappingExceptions;
2324
import static dev.langchain4j.internal.Utils.getOrDefault;
2425
import static dev.langchain4j.model.openai.internal.OpenAiUtils.DEFAULT_OPENAI_URL;
2526
import static dev.langchain4j.model.openai.internal.OpenAiUtils.DEFAULT_USER_AGENT;
@@ -71,7 +72,7 @@ private Response<Moderation> moderateInternal(List<String> inputs) {
7172
.input(inputs)
7273
.build();
7374

74-
ModerationResponse response = withRetry(() -> client.moderation(request).execute(), maxRetries);
75+
ModerationResponse response = withRetryMappingExceptions(() -> client.moderation(request).execute(), maxRetries);
7576

7677
int i = 0;
7778
for (ModerationResult moderationResult : response.results()) {

langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dev.langchain4j.model.openai;
22

33
import dev.langchain4j.http.client.HttpClientBuilder;
4+
import dev.langchain4j.internal.ExceptionMapper;
45
import dev.langchain4j.model.ModelProvider;
56
import dev.langchain4j.model.StreamingResponseHandler;
67
import dev.langchain4j.model.chat.StreamingChatModel;
@@ -136,7 +137,9 @@ public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler
136137
ChatResponse chatResponse = openAiResponseBuilder.build();
137138
handler.onCompleteResponse(chatResponse);
138139
})
139-
.onError(handler::onError)
140+
.onError(throwable -> {
141+
handler.onError(ExceptionMapper.DEFAULT.mapException(throwable));
142+
})
140143
.execute();
141144
}
142145

langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingLanguageModel.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dev.langchain4j.model.openai;
22

33
import dev.langchain4j.http.client.HttpClientBuilder;
4+
import dev.langchain4j.internal.ExceptionMapper;
45
import dev.langchain4j.model.StreamingResponseHandler;
56
import dev.langchain4j.model.chat.response.ChatResponse;
67
import dev.langchain4j.model.language.StreamingLanguageModel;
@@ -88,7 +89,9 @@ public void generate(String prompt, StreamingResponseHandler<String> handler) {
8889
chatResponse.metadata().finishReason()
8990
));
9091
})
91-
.onError(handler::onError)
92+
.onError(throwable -> {
93+
handler.onError(ExceptionMapper.DEFAULT.mapException(throwable));
94+
})
9295
.execute();
9396
}
9497

langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelErrorsTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class OpenAiChatModelErrorsTest {
2727

2828
private static final MockOpenai MOCK = new MockOpenai();
2929

30-
public static final Duration TIMEOUT = Duration.ofMillis(200);
30+
public static final Duration TIMEOUT = Duration.ofMillis(300);
3131

3232
ChatModel model = OpenAiChatModel.builder()
3333
.baseUrl(MOCK.baseUrl())
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package dev.langchain4j.model.openai;
2+
3+
import dev.langchain4j.exception.AuthenticationException;
4+
import dev.langchain4j.exception.HttpException;
5+
import dev.langchain4j.exception.InternalServerException;
6+
import dev.langchain4j.exception.InvalidRequestException;
7+
import dev.langchain4j.exception.LangChain4jException;
8+
import dev.langchain4j.exception.ModelNotFoundException;
9+
import dev.langchain4j.exception.RateLimitException;
10+
import dev.langchain4j.exception.TimeoutException;
11+
import dev.langchain4j.model.chat.StreamingChatModel;
12+
import dev.langchain4j.model.chat.response.ChatResponse;
13+
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
14+
import io.ktor.http.HttpStatusCode;
15+
import me.kpavlov.aimocks.openai.MockOpenai;
16+
import org.junit.jupiter.api.Test;
17+
import org.junit.jupiter.params.ParameterizedTest;
18+
import org.junit.jupiter.params.provider.Arguments;
19+
import org.junit.jupiter.params.provider.MethodSource;
20+
21+
import java.net.http.HttpTimeoutException;
22+
import java.time.Duration;
23+
import java.util.concurrent.CompletableFuture;
24+
import java.util.stream.Stream;
25+
26+
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
27+
import static java.util.concurrent.TimeUnit.SECONDS;
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
30+
class OpenAiStreamingChatModelErrorsTest {
31+
32+
private static final MockOpenai MOCK = new MockOpenai();
33+
34+
public static final Duration TIMEOUT = Duration.ofMillis(300);
35+
36+
StreamingChatModel model = OpenAiStreamingChatModel.builder()
37+
.baseUrl(MOCK.baseUrl())
38+
.modelName(GPT_4_O_MINI)
39+
.timeout(TIMEOUT)
40+
.logRequests(true)
41+
.logResponses(true)
42+
.build();
43+
44+
public static Stream<Arguments> errors() {
45+
return Stream.of(
46+
Arguments.of(400, InvalidRequestException.class),
47+
Arguments.of(401, AuthenticationException.class),
48+
Arguments.of(403, AuthenticationException.class),
49+
Arguments.of(404, ModelNotFoundException.class),
50+
Arguments.of(413, InvalidRequestException.class),
51+
Arguments.of(429, RateLimitException.class),
52+
Arguments.of(500, InternalServerException.class),
53+
Arguments.of(503, InternalServerException.class));
54+
}
55+
56+
@ParameterizedTest
57+
@MethodSource("errors")
58+
void should_handle_error_responses(int httpStatusCode, Class<LangChain4jException> exception) throws Exception {
59+
60+
// given
61+
final var question = "Return error: " + httpStatusCode;
62+
MOCK.completion(req -> req.userMessageContains(question)).respondsError(res -> {
63+
res.setHttpStatus(HttpStatusCode.Companion.fromValue(httpStatusCode));
64+
res.setBody("");
65+
});
66+
67+
CompletableFuture<Throwable> futureError = new CompletableFuture<>();
68+
StreamingChatResponseHandler handler = new ErrorHandler(futureError);
69+
70+
// when
71+
model.chat(question, handler);
72+
73+
// then
74+
Throwable error = futureError.get(30, SECONDS);
75+
76+
assertThat(error)
77+
.isExactlyInstanceOf(exception)
78+
.satisfies(ex -> assertThat(((HttpException) ex.getCause()).statusCode())
79+
.as("statusCode")
80+
.isEqualTo(httpStatusCode));
81+
}
82+
83+
@Test
84+
void should_handle_timeout() throws Exception {
85+
86+
// given
87+
final var question = "Simulate timeout";
88+
MOCK.completion(req -> req.userMessageContains(question)).respondsError(res -> {
89+
res.delayMillis(TIMEOUT.plusMillis(100).toMillis());
90+
res.setHttpStatus(HttpStatusCode.Companion.getNoContent());
91+
res.setBody("");
92+
});
93+
94+
CompletableFuture<Throwable> futureError = new CompletableFuture<>();
95+
StreamingChatResponseHandler handler = new ErrorHandler(futureError);
96+
97+
// when
98+
model.chat(question, handler);
99+
100+
// then
101+
Throwable error = futureError.get(30, SECONDS);
102+
103+
assertThat(error)
104+
.isExactlyInstanceOf(TimeoutException.class)
105+
.hasRootCauseExactlyInstanceOf(HttpTimeoutException.class);
106+
}
107+
108+
class ErrorHandler implements StreamingChatResponseHandler {
109+
110+
private final CompletableFuture<Throwable> futureError;
111+
112+
ErrorHandler(CompletableFuture<Throwable> futureError) {
113+
this.futureError = futureError;
114+
}
115+
116+
@Override
117+
public void onPartialResponse(String partialResponse) {
118+
futureError.completeExceptionally(new RuntimeException("onPartialResponse must not be called"));
119+
}
120+
121+
@Override
122+
public void onCompleteResponse(ChatResponse completeResponse) {
123+
futureError.completeExceptionally(new RuntimeException("onCompleteResponse must not be called"));
124+
}
125+
126+
@Override
127+
public void onError(Throwable error) {
128+
futureError.complete(error);
129+
}
130+
}
131+
}

langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelListenerIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
44
import static java.util.Collections.singletonList;
55

6+
import dev.langchain4j.exception.AuthenticationException;
67
import dev.langchain4j.exception.HttpException;
78
import dev.langchain4j.model.chat.StreamingChatModel;
89
import dev.langchain4j.model.chat.common.AbstractStreamingChatModelListenerIT;
@@ -46,6 +47,6 @@ protected StreamingChatModel createFailingModel(ChatModelListener listener) {
4647

4748
@Override
4849
protected Class<? extends Exception> expectedExceptionClass() {
49-
return HttpException.class;
50+
return AuthenticationException.class;
5051
}
5152
}

0 commit comments

Comments
 (0)