Skip to content

Commit dde7d95

Browse files
committed
Implement TopProbabilityNum and StopSequences for ollamaChat
1 parent 8a2ea28 commit dde7d95

File tree

3 files changed

+62
-37
lines changed

3 files changed

+62
-37
lines changed

+llms/+internal/callOllamaChatAPI.m

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,16 @@
2525
%
2626
% Example
2727
%
28+
% model = "mistral";
29+
%
2830
% % Create messages struct
2931
% messages = {struct("role", "system",...
3032
% "content", "You are a helpful assistant");
3133
% struct("role", "user", ...
3234
% "content", "What is the edit distance between hi and hello?")};
3335
%
34-
% % Create functions struct
35-
% functions = {struct("name", "editDistance", ...
36-
% "description", "Find edit distance between two strings or documents.", ...
37-
% "parameters", struct( ...
38-
% "type", "object", ...
39-
% "properties", struct(...
40-
% "str1", struct(...
41-
% "description", "Source string.", ...
42-
% "type", "string"),...
43-
% "str2", struct(...
44-
% "description", "Target string.", ...
45-
% "type", "string")),...
46-
% "required", ["str1", "str2"]))};
47-
%
48-
% % Define your API key
49-
% apiKey = "your-api-key-here"
50-
%
5136
% % Send a request
52-
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey)
37+
% [text, message] = llms.internal.callOllamaChatAPI(model, messages)
5338

5439
% Copyright 2023-2024 The MathWorks, Inc.
5540

@@ -58,6 +43,7 @@
5843
messages
5944
nvp.Temperature = 1
6045
nvp.TopProbabilityMass = 1
46+
nvp.TopProbabilityNum = Inf
6147
nvp.NumCompletions = 1
6248
nvp.StopSequences = []
6349
nvp.MaxNumTokens = inf
@@ -71,6 +57,12 @@
7157

7258
URL = "http://localhost:11434/api/chat"; % TODO: model parameter
7359

60+
% The JSON for StopSequences must have an array, and cannot say "stop": "foo".
61+
% The easiest way to ensure that is to never pass in a scalar: a single stop sequence is duplicated below so that it is encoded as a JSON array.
62+
if isscalar(nvp.StopSequences)
63+
nvp.StopSequences = [nvp.StopSequences, nvp.StopSequences];
64+
end
65+
7466
parameters = buildParametersCall(model, messages, nvp);
7567

7668
[response, streamedText] = llms.internal.sendRequest(parameters,[],URL,nvp.TimeOut,nvp.StreamFun);
@@ -123,6 +115,7 @@
123115
dict = dictionary();
124116
dict("Temperature") = "temperature";
125117
dict("TopProbabilityMass") = "top_p";
118+
dict("TopProbabilityNum") = "top_k";
126119
dict("NumCompletions") = "n";
127120
dict("StopSequences") = "stop";
128121
dict("MaxNumTokens") = "num_predict";

ollamaChat.m

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
classdef(Sealed) ollamaChat < llms.internal.textGenerator
1+
classdef (Sealed) ollamaChat < llms.internal.textGenerator
22
%ollamaChat Chat completion API from Ollama.
33
%
44
% CHAT = ollamaChat(modelName) creates an ollamaChat object for the given model.
@@ -68,8 +68,9 @@
6868

6969
% Copyright 2024 The MathWorks, Inc.
7070

71-
properties(SetAccess=private)
71+
properties
7272
Model (1,1) string
73+
TopProbabilityNum (1,1) {mustBeReal,mustBePositive} = Inf
7374
end
7475

7576
methods
@@ -79,6 +80,7 @@
7980
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
8081
nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
8182
nvp.TopProbabilityMass {llms.utils.mustBeValidTopP} = 1
83+
nvp.TopProbabilityNum (1,1) {mustBeReal,mustBePositive} = Inf
8284
nvp.StopSequences {llms.utils.mustBeValidStop} = {}
8385
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
8486
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
@@ -102,6 +104,7 @@
102104
this.ResponseFormat = nvp.ResponseFormat;
103105
this.Temperature = nvp.Temperature;
104106
this.TopProbabilityMass = nvp.TopProbabilityMass;
107+
this.TopProbabilityNum = nvp.TopProbabilityNum;
105108
this.StopSequences = nvp.StopSequences;
106109
this.TimeOut = nvp.TimeOut;
107110
end
@@ -152,6 +155,7 @@
152155
[text, message, response] = llms.internal.callOllamaChatAPI(...
153156
this.Model, messagesStruct, ...
154157
Temperature=this.Temperature, ...
158+
TopProbabilityMass=this.TopProbabilityMass, TopProbabilityNum=this.TopProbabilityNum,...
155159
NumCompletions=nvp.NumCompletions,...
156160
StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
157161
ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...

tests/tollamaChat.m

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
InvalidConstructorInput = iGetInvalidConstructorInput;
88
InvalidGenerateInput = iGetInvalidGenerateInput;
99
InvalidValuesSetters = iGetInvalidValuesSetters;
10+
ValidValuesSetters = iGetValidValuesSetters;
1011
end
1112

1213
methods(Test)
@@ -37,6 +38,26 @@ function doGenerate(testCase)
3738
testCase.verifyGreaterThan(strlength(response),0);
3839
end
3940

41+
function extremeTopK(testCase)
42+
% setting top-k to k=1 leaves no random choice,
43+
% so we expect to get a fixed response.
44+
chat = ollamaChat("mistral",TopProbabilityNum=1);
45+
prompt = "Top-k sampling with k=1 returns a definite answer.";
46+
response1 = generate(chat,prompt);
47+
response2 = generate(chat,prompt);
48+
testCase.verifyEqual(response1,response2);
49+
end
50+
51+
function stopSequences(testCase)
52+
chat = ollamaChat("mistral",TopProbabilityNum=1);
53+
prompt = "Top-k sampling with k=1 returns a definite answer.";
54+
response1 = generate(chat,prompt);
55+
chat.StopSequences = "1";
56+
response2 = generate(chat,prompt);
57+
58+
testCase.verifyEqual(response2, extractBefore(response1,"1"));
59+
end
60+
4061
%% Test is currently unreliable, reasons unclear
4162
% function verySmallTimeOutErrors(testCase)
4263
% chat = ollamaChat("mistral", TimeOut=1e-10);
@@ -60,6 +81,15 @@ function assignValueToProperty(property, value)
6081

6182
testCase.verifyError(@() assignValueToProperty(InvalidValuesSetters.Property,InvalidValuesSetters.Value), InvalidValuesSetters.Error);
6283
end
84+
85+
function validSetters(testCase, ValidValuesSetters)
86+
chat = ollamaChat("mistral");
87+
function assignValueToProperty(property, value)
88+
chat.(property) = value;
89+
end
90+
91+
testCase.verifyWarningFree(@() assignValueToProperty(ValidValuesSetters.Property,ValidValuesSetters.Value));
92+
end
6393
end
6494
end
6595

@@ -119,17 +149,19 @@ function assignValueToProperty(property, value)
119149
"EmptyStopSequences", struct( ...
120150
"Property", "StopSequences", ...
121151
"Value", "", ...
122-
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
123-
...
124-
"WrongSizeStopSequences", struct( ...
125-
"Property", "StopSequences", ...
126-
"Value", ["1" "2" "3" "4" "5"], ...
127-
"Error", "llms:stopSequencesMustHaveMax4Elements"), ...
128-
...
129-
"InvalidPresencePenalty", struct( ...
130-
"Property", "PresencePenalty", ...
131-
"Value", "2", ...
132-
"Error", "MATLAB:invalidType"));
152+
"Error", "MATLAB:validators:mustBeNonzeroLengthText"));
153+
end
154+
155+
function validSetters = iGetValidValuesSetters
156+
validSetters = struct(...
157+
"SmallTopNum", struct( ...
158+
"Property", "TopProbabilityNum", ...
159+
"Value", 2));
160+
% Currently disabled because it requires some code reorganization
161+
% and we have higher priorities ...
162+
% "ManyStopSequences", struct( ...
163+
% "Property", "StopSequences", ...
164+
% "Value", ["1" "2" "3" "4" "5"]));
133165
end
134166

135167
function invalidConstructorInput = iGetInvalidConstructorInput
@@ -196,7 +228,7 @@ function assignValueToProperty(property, value)
196228
...
197229
"TopProbabilityMassTooSmall",struct( ...
198230
"Input",{{ "TopProbabilityMass" -20 }},...
199-
"Error","MATLAB:expectedNonnegative"),...
231+
"Error","MATLAB:expectedNonnegative"),...
200232
...
201233
"WrongTypeStopSequences",struct( ...
202234
"Input",{{ "StopSequences" 123}},...
@@ -208,11 +240,7 @@ function assignValueToProperty(property, value)
208240
...
209241
"EmptyStopSequences",struct( ...
210242
"Input",{{ "StopSequences" ""}},...
211-
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
212-
...
213-
"WrongSizeStopSequences",struct( ...
214-
"Input",{{ "StopSequences" ["1" "2" "3" "4" "5"]}},...
215-
"Error","llms:stopSequencesMustHaveMax4Elements"));
243+
"Error","MATLAB:validators:mustBeNonzeroLengthText"));
216244
end
217245

218246
function invalidGenerateInput = iGetInvalidGenerateInput

0 commit comments

Comments
 (0)