Commit a6b8d51

add ollamaChat class
1 parent e009e86 commit a6b8d51

4 files changed: +611 additions, -6 deletions

+llms/+internal/callOllamaChatAPI.m

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
function [text, message, response] = callOllamaChatAPI(model, messages, nvp)
%callOllamaChatAPI Calls the Ollama chat completions API.
%
%   MESSAGES should be structs matching the JSON format required by the
%   Ollama chat completions API.
%   Ref: https://github.com/ollama/ollama/blob/main/docs/api.md
%
%   Currently, the supported NVP are, including the equivalent name in the API:
%   TODO TODO TODO
%    - Temperature (temperature)
%    - TopProbabilityMass (top_p)
%    - NumCompletions (n)
%    - StopSequences (stop)
%    - MaxNumTokens (max_tokens)
%    - PresencePenalty (presence_penalty)
%    - FrequencyPenalty (frequency_penalty)
%    - ResponseFormat (response_format)
%    - Seed (seed)
%    - TimeOut
%    - StreamFun
%   More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
%
%   Example
%
%   % Create messages struct
%   messages = {struct("role", "system",...
%       "content", "You are a helpful assistant");
%       struct("role", "user", ...
%       "content", "What is the edit distance between hi and hello?")};
%
%   % Send a request to a local model, e.g. "mistral"
%   [text, message] = llms.internal.callOllamaChatAPI("mistral", messages)

%   Copyright 2023-2024 The MathWorks, Inc.

arguments
    model
    messages
    nvp.Temperature = 1
    nvp.TopProbabilityMass = 1
    nvp.NumCompletions = 1
    nvp.StopSequences = []
    nvp.MaxNumTokens = inf
    nvp.PresencePenalty = 0
    nvp.FrequencyPenalty = 0
    nvp.ResponseFormat = "text"
    nvp.Seed = []
    nvp.TimeOut = 10
    nvp.StreamFun = []
end

URL = "http://localhost:11434/api/chat"; % TODO: model parameter

parameters = buildParametersCall(model, messages, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,[],URL,nvp.TimeOut,nvp.StreamFun);

% If the call errors, "message" will not be part of response.Body.Data;
% instead we get response.Body.Data.error
if response.StatusCode=="OK"
    % Output the first generation
    if isempty(nvp.StreamFun)
        message = response.Body.Data.message;
    else
        message = struct("role", "assistant", ...
            "content", streamedText);
    end
    text = string(message.content);
else
    text = "";
    message = struct();
end
end

function parameters = buildParametersCall(model, messages, nvp)
% Builds a struct in the format that is expected by the API, combining
% MESSAGES and the parameters in NVP.

parameters = struct();
parameters.model = model;
parameters.messages = messages;

parameters.stream = ~isempty(nvp.StreamFun);

options = struct;
if ~isempty(nvp.Seed)
    options.seed = nvp.Seed;
end

% Translate the MATLAB-style NVP names to the option names the API expects
dict = mapNVPToParameters;

nvpOptions = keys(dict);
for opt = nvpOptions.'
    if isfield(nvp, opt)
        options.(dict(opt)) = nvp.(opt);
    end
end

parameters.options = options;
end

function dict = mapNVPToParameters()
dict = dictionary();
dict("Temperature") = "temperature";
dict("TopProbabilityMass") = "top_p";
dict("NumCompletions") = "n";
dict("StopSequences") = "stop";
dict("MaxNumTokens") = "num_predict";
end

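For context, a minimal sketch of calling the new wrapper directly (it lives in an internal package, so end users would normally go through the ollamaChat class below). It assumes a local Ollama server is running on the default port and that the "mistral" model has been pulled; both are assumptions, not part of this commit:

% Minimal usage sketch. Assumes `ollama serve` is running locally and
% the "mistral" model has been pulled (illustrative, not from this commit).
messages = {struct("role", "user", "content", "Say hello in five words.")};
[text, message, response] = llms.internal.callOllamaChatAPI( ...
    "mistral", messages, Temperature=0.5, TimeOut=30);
disp(text)                  % generated text as a string scalar
disp(response.StatusCode)   % "OK" on success
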
+llms/+internal/sendRequest.m

Lines changed: 8 additions & 6 deletions
@@ -15,18 +15,20 @@

 % Define the headers for the API request

-headers = [matlab.net.http.HeaderField('Content-Type', 'application/json')...
-    matlab.net.http.HeaderField('Authorization', "Bearer " + token)...
-    matlab.net.http.HeaderField('api-key',token)];
+headers = matlab.net.http.HeaderField('Content-Type', 'application/json');
+if ~isempty(token)
+    headers = [headers ...
+        matlab.net.http.HeaderField('Authorization', "Bearer " + token)...
+        matlab.net.http.HeaderField('api-key',token)];
+end

 % Define the request message
 request = matlab.net.http.RequestMessage('post',headers,parameters);

-% Create a HTTPOptions object;
+% set the timeout
 httpOpts = matlab.net.http.HTTPOptions;
-
-% Set the ConnectTimeout option
 httpOpts.ConnectTimeout = timeout;
+httpOpts.ResponseTimeout = timeout;

 % Send the request and store the response
 if isempty(streamFun)

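The practical effect of this diff: when the token argument is empty (as it is for a local Ollama server, which needs no API key), only the Content-Type header is sent, and the response timeout now honors the same timeout value as the connection. A small sketch of the new header logic in isolation (the token value is illustrative):

% Sketch of the new header-building logic on its own.
token = "";   % empty token: no Authorization/api-key headers are added
headers = matlab.net.http.HeaderField('Content-Type', 'application/json');
if ~isempty(token)
    headers = [headers ...
        matlab.net.http.HeaderField('Authorization', "Bearer " + token) ...
        matlab.net.http.HeaderField('api-key', token)];
end
disp(numel(headers))   % 1: only Content-Type is sent
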
ollamaChat.m

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
classdef(Sealed) ollamaChat < llms.internal.textGenerator
%ollamaChat Chat completion API from Ollama.
%
%   CHAT = ollamaChat(modelName) creates an ollamaChat object for the given model.
%
%   CHAT = ollamaChat(__,systemPrompt) creates an ollamaChat object with the
%   specified system prompt.
%
%   CHAT = ollamaChat(__,Name=Value) specifies additional options
%   using one or more name-value arguments:
%
%   Temperature        - Temperature value for controlling the randomness
%                        of the output. The default depends on the model;
%                        if not specified in the model, it is 0.8.
%
%   TODO: TopK and TopP, how do they relate to this?
%   TopProbabilityMass - Top probability mass value for controlling the
%                        diversity of the output. Default value is 1.
%
%   StopSequences      - Vector of strings that, when encountered, will
%                        stop the generation of tokens. Default
%                        value is empty.
%
%   ResponseFormat     - The format of response the model returns.
%                        "text" (default) | "json"
%
%   Seed               - TODO: Seems to have no effect whatsoever (tested via curl) - cf. https://github.com/ollama/ollama/issues/4660
%
%   Mirostat           - 0/1/2, eta, tau
%
%   RepeatLastN        - find a better name! "Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)"
%
%   RepeatPenalty
%
%   TailFreeSamplingZ
%
%   MaxNumTokens
%
%   StreamFun          - Function to callback when streaming the result.
%
%   TimeOut            - Connection timeout in seconds (default: 10 secs).
%
%   ollamaChat Functions:
%       ollamaChat         - Chat completion API from Ollama.
%       generate           - Generate a response using the ollamaChat instance.
%
%   ollamaChat Properties: TODO TODO
%       Temperature        - Temperature of generation.
%       TopProbabilityMass - Top probability mass to consider for generation.
%       StopSequences      - Sequences to stop the generation of tokens.
%       PresencePenalty    - Penalty for using a token in the response
%                            that has already been used.
%       FrequencyPenalty   - Penalty for using a token that is frequent
%                            in the training data.
%       SystemPrompt       - System prompt.
%       ResponseFormat     - Specifies the response format, "text" or "json".
%       TimeOut            - Connection timeout in seconds (default: 10 secs).

%   Copyright 2024 The MathWorks, Inc.

    properties(SetAccess=private)
        Model (1,1) string
    end

    methods
        function this = ollamaChat(modelName, systemPrompt, nvp)
            arguments
                modelName {mustBeTextScalar}
                systemPrompt {llms.utils.mustBeTextOrEmpty} = []
                nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
                nvp.TopProbabilityMass {llms.utils.mustBeValidTopP} = 1
                nvp.StopSequences {llms.utils.mustBeValidStop} = {}
                nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
                nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
                nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
            end

            if isfield(nvp,"StreamFun")
                this.StreamFun = nvp.StreamFun;
            else
                this.StreamFun = [];
            end

            if ~isempty(systemPrompt)
                systemPrompt = string(systemPrompt);
                if ~(strlength(systemPrompt)==0)
                    this.SystemPrompt = {struct("role", "system", "content", systemPrompt)};
                end
            end

            this.Model = modelName;
            this.ResponseFormat = nvp.ResponseFormat;
            this.Temperature = nvp.Temperature;
            this.TopProbabilityMass = nvp.TopProbabilityMass;
            this.StopSequences = nvp.StopSequences;
            this.TimeOut = nvp.TimeOut;
        end

        function [text, message, response] = generate(this, messages, nvp)
            %generate Generate a response using the ollamaChat instance.
            %
            %   [TEXT, MESSAGE, RESPONSE] = generate(CHAT, MESSAGES) generates a response
            %   with the specified MESSAGES.
            %
            %   [TEXT, MESSAGE, RESPONSE] = generate(__, Name=Value) specifies additional options
            %   using one or more name-value arguments:
            %
            %   NumCompletions - Number of completions to generate.
            %                    Default value is 1.
            %
            %   MaxNumTokens   - Maximum number of tokens in the generated response.
            %                    Default value is inf.
            %
            %   Seed           - An integer value to use to obtain
            %                    reproducible responses.

            arguments
                this (1,1) ollamaChat
                messages (1,1) {mustBeValidMsgs}
                nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1
                nvp.MaxNumTokens (1,1) {mustBePositive} = inf
                nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
            end

            if isstring(messages) && isscalar(messages)
                messagesStruct = {struct("role", "user", "content", messages)};
            else
                messagesStruct = messages.Messages;
            end

            if ~isempty(this.SystemPrompt)
                messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
            end

            [text, message, response] = llms.internal.callOllamaChatAPI(...
                this.Model, messagesStruct, ...
                Temperature=this.Temperature, ...
                NumCompletions=nvp.NumCompletions,...
                StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
                ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
                TimeOut=this.TimeOut, StreamFun=this.StreamFun);
        end
    end

    methods(Hidden)
        function mustBeValidFunctionCall(this, functionCall)
            if ~isempty(functionCall)
                mustBeTextScalar(functionCall);
                if isempty(this.FunctionNames)
                    error("llms:mustSetFunctionsForCall", llms.utils.errorMessageCatalog.getMessage("llms:mustSetFunctionsForCall"));
                end
                mustBeMember(functionCall, ["none","auto", this.FunctionNames]);
            end
        end

        function toolChoice = convertToolChoice(this, toolChoice)
            % if toolChoice is empty
            if isempty(toolChoice)
                % if Tools is not empty, the default is 'auto'
                if ~isempty(this.Tools)
                    toolChoice = "auto";
                end
            elseif toolChoice ~= "auto"
                % if toolChoice is not empty, then it must be in the format
                % {"type": "function", "function": {"name": "my_function"}}
                toolChoice = struct("type","function","function",struct("name",toolChoice));
            end
        end
    end
end

function mustBeNonzeroLengthTextScalar(content)
mustBeNonzeroLengthText(content)
mustBeTextScalar(content)
end

function [functionsStruct, functionNames] = functionAsStruct(functions)
numFunctions = numel(functions);
functionsStruct = cell(1, numFunctions);
functionNames = strings(1, numFunctions);

for i = 1:numFunctions
    functionsStruct{i} = struct('type','function', ...
        'function',encodeStruct(functions(i))) ;
    functionNames(i) = functions(i).FunctionName;
end
end

function mustBeValidMsgs(value)
if isa(value, "openAIMessages")
    if numel(value.Messages) == 0
        error("llms:mustHaveMessages", llms.utils.errorMessageCatalog.getMessage("llms:mustHaveMessages"));
    end
else
    try
        llms.utils.mustBeNonzeroLengthTextScalar(value);
    catch ME
        error("llms:mustBeMessagesOrTxt", llms.utils.errorMessageCatalog.getMessage("llms:mustBeMessagesOrTxt"));
    end
end
end

function mustBeIntegerOrEmpty(value)
if ~isempty(value)
    mustBeInteger(value)
end
end

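Putting the new class together end to end, a minimal sketch of the intended workflow. The model name "mistral", the prompts, and a running local Ollama server are illustrative assumptions, not part of this commit:

% End-to-end sketch (assumes a local Ollama server with "mistral" pulled).
chat = ollamaChat("mistral", "You are a concise assistant.", ...
    Temperature=0.7, TimeOut=60);
[text, message] = generate(chat, "What is the capital of France?");
disp(text)

% Streaming variant: the StreamFun callback prints tokens as they arrive.
chatStreamed = ollamaChat("mistral", StreamFun=@(token) fprintf("%s", token));
generate(chatStreamed, "Count from 1 to 5.");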