Skip to content

Commit 10e1e13

Browse files
ThomasVitale authored and markpollack committed
Add observability support for Ollama
- Improve ITs to reuse a single container across tests
1 parent 3bef5c1 commit 10e1e13

File tree

27 files changed

+696
-221
lines changed

27 files changed

+696
-221
lines changed

models/spring-ai-ollama/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@
7474
<scope>test</scope>
7575
</dependency>
7676

77+
<dependency>
78+
<groupId>io.micrometer</groupId>
79+
<artifactId>micrometer-observation-test</artifactId>
80+
<scope>test</scope>
81+
</dependency>
82+
7783
<dependency>
7884
<groupId>org.testcontainers</groupId>
7985
<artifactId>junit-jupiter</artifactId>

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 156 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,22 @@
2121
import java.util.Map;
2222
import java.util.Set;
2323

24+
import io.micrometer.observation.Observation;
25+
import io.micrometer.observation.ObservationRegistry;
26+
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
2427
import org.springframework.ai.chat.messages.AssistantMessage;
2528
import org.springframework.ai.chat.messages.SystemMessage;
2629
import org.springframework.ai.chat.messages.ToolResponseMessage;
2730
import org.springframework.ai.chat.messages.UserMessage;
2831
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
2932
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
30-
import org.springframework.ai.chat.model.AbstractToolCallSupport;
31-
import org.springframework.ai.chat.model.ChatModel;
32-
import org.springframework.ai.chat.model.ChatResponse;
33-
import org.springframework.ai.chat.model.Generation;
33+
import org.springframework.ai.chat.model.*;
34+
import org.springframework.ai.chat.observation.ChatModelObservationContext;
35+
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
36+
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
37+
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
3438
import org.springframework.ai.chat.prompt.ChatOptions;
39+
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
3540
import org.springframework.ai.chat.prompt.Prompt;
3641
import org.springframework.ai.model.ModelOptionsUtils;
3742
import org.springframework.ai.model.function.FunctionCallback;
@@ -64,6 +69,8 @@
6469
*/
6570
public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {
6671

72+
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
73+
6774
/**
6875
* Low-level Ollama API library.
6976
*/
@@ -72,61 +79,97 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
7279
/**
7380
* Default options to be used for all chat requests.
7481
*/
75-
private OllamaOptions defaultOptions;
82+
private final OllamaOptions defaultOptions;
83+
84+
/**
85+
* Observation registry used for instrumentation.
86+
*/
87+
private final ObservationRegistry observationRegistry;
7688

77-
public OllamaChatModel(OllamaApi chatApi) {
78-
this(chatApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
89+
/**
90+
* Conventions to use for generating observations.
91+
*/
92+
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
93+
94+
public OllamaChatModel(OllamaApi ollamaApi) {
95+
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
7996
}
8097

81-
public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions) {
82-
this(chatApi, defaultOptions, null);
98+
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
99+
this(ollamaApi, defaultOptions, null);
83100
}
84101

85-
public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions,
102+
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
86103
FunctionCallbackContext functionCallbackContext) {
87-
this(chatApi, defaultOptions, functionCallbackContext, List.of());
104+
this(ollamaApi, defaultOptions, functionCallbackContext, List.of());
88105
}
89106

90-
public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions,
107+
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
91108
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks) {
109+
this(ollamaApi, defaultOptions, functionCallbackContext, toolFunctionCallbacks, ObservationRegistry.NOOP);
110+
}
111+
112+
public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions,
113+
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
114+
ObservationRegistry observationRegistry) {
92115
super(functionCallbackContext, defaultOptions, toolFunctionCallbacks);
93-
Assert.notNull(chatApi, "OllamaApi must not be null");
94-
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
116+
Assert.notNull(chatApi, "ollamaApi must not be null");
117+
Assert.notNull(defaultOptions, "defaultOptions must not be null");
118+
Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
95119
this.chatApi = chatApi;
96120
this.defaultOptions = defaultOptions;
121+
this.observationRegistry = observationRegistry;
97122
}
98123

99124
@Override
100125
public ChatResponse call(Prompt prompt) {
126+
OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false);
101127

102-
OllamaApi.ChatResponse response = this.chatApi.chat(ollamaChatRequest(prompt, false));
128+
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
129+
.prompt(prompt)
130+
.provider(OllamaApi.PROVIDER_NAME)
131+
.requestOptions(buildRequestOptions(request))
132+
.build();
103133

104-
List<AssistantMessage.ToolCall> toolCalls = response.message().toolCalls() == null ? List.of()
105-
: response.message()
106-
.toolCalls()
107-
.stream()
108-
.map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
109-
ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
110-
.toList();
134+
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
135+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
136+
this.observationRegistry)
137+
.observe(() -> {
111138

112-
var assistantMessage = new AssistantMessage(response.message().content(), Map.of(), toolCalls);
139+
OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
113140

114-
ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
115-
if (response.promptEvalCount() != null && response.evalCount() != null) {
116-
generationMetadata = ChatGenerationMetadata.from(response.doneReason(), null);
117-
}
141+
List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
142+
: ollamaResponse.message()
143+
.toolCalls()
144+
.stream()
145+
.map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
146+
ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
147+
.toList();
148+
149+
var assistantMessage = new AssistantMessage(ollamaResponse.message().content(), Map.of(), toolCalls);
150+
151+
ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
152+
if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) {
153+
generationMetadata = ChatGenerationMetadata.from(ollamaResponse.doneReason(), null);
154+
}
155+
156+
var generator = new Generation(assistantMessage, generationMetadata);
157+
ChatResponse chatResponse = new ChatResponse(List.of(generator), from(ollamaResponse));
158+
159+
observationContext.setResponse(chatResponse);
160+
161+
return chatResponse;
118162

119-
var generator = new Generation(assistantMessage, generationMetadata);
120-
var chatResponse = new ChatResponse(List.of(generator), from(response));
163+
});
121164

122-
if (isToolCall(chatResponse, Set.of("stop"))) {
123-
var toolCallConversation = handleToolCalls(prompt, chatResponse);
165+
if (response != null && isToolCall(response, Set.of("stop"))) {
166+
var toolCallConversation = handleToolCalls(prompt, response);
124167
// Recursively call the call method with the tool call message
125168
// conversation that contains the call responses.
126169
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
127170
}
128171

129-
return chatResponse;
172+
return response;
130173
}
131174

132175
public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
@@ -147,40 +190,64 @@ public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
147190

148191
@Override
149192
public Flux<ChatResponse> stream(Prompt prompt) {
193+
return Flux.deferContextual(contextView -> {
194+
OllamaApi.ChatRequest request = ollamaChatRequest(prompt, true);
195+
196+
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
197+
.prompt(prompt)
198+
.provider(OllamaApi.PROVIDER_NAME)
199+
.requestOptions(buildRequestOptions(request))
200+
.build();
201+
202+
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
203+
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
204+
this.observationRegistry);
205+
206+
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
207+
208+
Flux<OllamaApi.ChatResponse> ollamaResponse = this.chatApi.streamingChat(request);
209+
210+
Flux<ChatResponse> chatResponse = ollamaResponse.map(chunk -> {
211+
String content = (chunk.message() != null) ? chunk.message().content() : "";
212+
List<AssistantMessage.ToolCall> toolCalls = chunk.message().toolCalls() == null ? List.of()
213+
: chunk.message()
214+
.toolCalls()
215+
.stream()
216+
.map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
217+
ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
218+
.toList();
219+
220+
var assistantMessage = new AssistantMessage(content, Map.of(), toolCalls);
221+
222+
ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
223+
if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
224+
generationMetadata = ChatGenerationMetadata.from(chunk.doneReason(), null);
225+
}
150226

151-
Flux<OllamaApi.ChatResponse> ollamaResponse = this.chatApi.streamingChat(ollamaChatRequest(prompt, true));
152-
153-
Flux<ChatResponse> chatResponse = ollamaResponse.map(chunk -> {
154-
String content = (chunk.message() != null) ? chunk.message().content() : "";
155-
List<AssistantMessage.ToolCall> toolCalls = chunk.message().toolCalls() == null ? List.of()
156-
: chunk.message()
157-
.toolCalls()
158-
.stream()
159-
.map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
160-
ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
161-
.toList();
162-
163-
var assistantMessage = new AssistantMessage(content, Map.of(), toolCalls);
164-
165-
ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
166-
if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
167-
generationMetadata = ChatGenerationMetadata.from(chunk.doneReason(), null);
168-
}
169-
170-
var generator = new Generation(assistantMessage, generationMetadata);
171-
return new ChatResponse(List.of(generator), from(chunk));
172-
});
173-
174-
return chatResponse.flatMap(response -> {
175-
if (isToolCall(response, Set.of("stop"))) {
176-
var toolCallConversation = handleToolCalls(prompt, response);
177-
// Recursively call the stream method with the tool call message
178-
// conversation that contains the call responses.
179-
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
180-
}
181-
else {
182-
return Flux.just(response);
183-
}
227+
var generator = new Generation(assistantMessage, generationMetadata);
228+
return new ChatResponse(List.of(generator), from(chunk));
229+
});
230+
231+
// @formatter:off
232+
Flux<ChatResponse> chatResponseFlux = chatResponse.flatMap(response -> {
233+
if (isToolCall(response, Set.of("stop"))) {
234+
var toolCallConversation = handleToolCalls(prompt, response);
235+
// Recursively call the stream method with the tool call message
236+
// conversation that contains the call responses.
237+
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
238+
}
239+
else {
240+
return Flux.just(response);
241+
}
242+
})
243+
.doOnError(observation::error)
244+
.doFinally(s -> {
245+
observation.stop();
246+
})
247+
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
248+
// @formatter:on
249+
250+
return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
184251
});
185252
}
186253

@@ -216,13 +283,10 @@ else if (message instanceof AssistantMessage assistantMessage) {
216283
.build());
217284
}
218285
else if (message instanceof ToolResponseMessage toolMessage) {
219-
220-
List<OllamaApi.Message> responseMessages = toolMessage.getResponses()
286+
return toolMessage.getResponses()
221287
.stream()
222288
.map(tr -> OllamaApi.Message.builder(Role.TOOL).withContent(tr.responseData()).build())
223289
.toList();
224-
225-
return responseMessages;
226290
}
227291
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
228292
}).flatMap(List::stream).toList();
@@ -290,9 +354,32 @@ private List<ChatRequest.Tool> getFunctionTools(Set<String> functionNames) {
290354
}).toList();
291355
}
292356

357+
private ChatOptions buildRequestOptions(OllamaApi.ChatRequest request) {
358+
var options = ModelOptionsUtils.mapToClass(request.options(), OllamaOptions.class);
359+
return ChatOptionsBuilder.builder()
360+
.withModel(request.model())
361+
.withFrequencyPenalty(options.getFrequencyPenalty())
362+
.withMaxTokens(options.getMaxTokens())
363+
.withPresencePenalty(options.getPresencePenalty())
364+
.withStopSequences(options.getStopSequences())
365+
.withTemperature(options.getTemperature())
366+
.withTopK(options.getTopK())
367+
.withTopP(options.getTopP())
368+
.build();
369+
}
370+
293371
@Override
294372
public ChatOptions getDefaultOptions() {
295373
return OllamaOptions.fromOptions(this.defaultOptions);
296374
}
297375

376+
/**
377+
* Use the provided convention for reporting observation data
378+
* @param observationConvention The provided convention
379+
*/
380+
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
381+
Assert.notNull(observationConvention, "observationConvention cannot be null");
382+
this.observationConvention = observationConvention;
383+
}
384+
298385
}

0 commit comments

Comments (0)