Skip to content

Commit 19effe4

Let generate temporarily override model settings (#2)
1 parent 9ce8f94 commit 19effe4

7 files changed: +266 -35 lines changed


azureChat.m

Lines changed: 59 additions & 6 deletions

@@ -174,13 +174,59 @@
 %   Seed               - An integer value to use to obtain
 %                        reproducible responses
 %
+%   Temperature        - Temperature value for controlling the randomness
+%                        of the output. The default value is CHAT.Temperature;
+%                        higher values increase the randomness (in some sense,
+%                        the “creativity”) of outputs, lower values
+%                        reduce it. Setting Temperature=0 removes
+%                        randomness from the output altogether.
+%
+%   TopP               - Top probability mass value for controlling the
+%                        diversity of the output. Default value is CHAT.TopP;
+%                        lower values imply that only the more likely
+%                        words can appear in any particular place.
+%                        This is also known as top-p sampling.
+%
+%   StopSequences      - Vector of strings that when encountered, will
+%                        stop the generation of tokens. Default
+%                        value is CHAT.StopSequences.
+%                        Example: ["The end.", "And that's all she wrote."]
+%
+%   ResponseFormat     - The format of response the model returns.
+%                        Default value is CHAT.ResponseFormat.
+%                        "text" | "json"
+%
+%   PresencePenalty    - Penalty value for using a token in the response
+%                        that has already been used. Default value is
+%                        CHAT.PresencePenalty.
+%                        Higher values reduce repetition of words in the output.
+%
+%   FrequencyPenalty   - Penalty value for using a token that is frequent
+%                        in the output. Default value is CHAT.FrequencyPenalty.
+%                        Higher values reduce repetition of words in the output.
+%
+%   StreamFun          - Function to callback when streaming the result.
+%                        Default value is CHAT.StreamFun.
+%
+%   TimeOut            - Connection Timeout in seconds. Default value is CHAT.TimeOut.
+%
+%
 %   Currently, GPT-4 Turbo with vision does not support the message.name
 %   parameter, functions/tools, response_format parameter, stop
 %   sequences, and max_tokens

 arguments
     this (1,1) azureChat
     messages {mustBeValidMsgs}
+    nvp.Temperature {llms.utils.mustBeValidTemperature} = this.Temperature
+    nvp.TopP {llms.utils.mustBeValidProbability} = this.TopP
+    nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences
+    nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = this.ResponseFormat
+    nvp.APIKey {mustBeNonzeroLengthTextScalar} = this.APIKey
+    nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = this.PresencePenalty
+    nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = this.FrequencyPenalty
+    nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut
+    nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
     nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1
     nvp.MaxNumTokens (1,1) {mustBePositive} = inf
     nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []

@@ -199,15 +245,22 @@
 end

 toolChoice = convertToolChoice(this, nvp.ToolChoice);
+
+if isfield(nvp,"StreamFun")
+    streamFun = nvp.StreamFun;
+else
+    streamFun = this.StreamFun;
+end
+
 try
     [text, message, response] = llms.internal.callAzureChatAPI(this.Endpoint, ...
         this.DeploymentID, messagesStruct, this.FunctionsStruct, ...
-        ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=this.Temperature, ...
-        TopP=this.TopP, NumCompletions=nvp.NumCompletions,...
-        StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
-        PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ...
-        ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
-        APIKey=this.APIKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun);
+        ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=nvp.Temperature, ...
+        TopP=nvp.TopP, NumCompletions=nvp.NumCompletions,...
+        StopSequences=nvp.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
+        PresencePenalty=nvp.PresencePenalty, FrequencyPenalty=nvp.FrequencyPenalty, ...
+        ResponseFormat=nvp.ResponseFormat,Seed=nvp.Seed, ...
+        APIKey=nvp.APIKey,TimeOut=nvp.TimeOut,StreamFun=streamFun);
 catch ME
     if ismember(ME.identifier,...
             ["MATLAB:webservices:UnknownHost","MATLAB:webservices:Timeout"])

functionSignatures.json

Lines changed: 36 additions & 4 deletions

@@ -31,7 +31,17 @@
     {"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]},
     {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]},
     {"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"},
-    {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]}
+    {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]},
+    {"name":"ModelName","kind":"namevalue","type":"choices=llms.openai.models"},
+    {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
+    {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
+    {"name":"StopSequences","kind":"namevalue","type":["string","vector"]},
+    {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"},
+    {"name":"APIKey","kind":"namevalue","type":["string","scalar"]},
+    {"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
+    {"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
+    {"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
+    {"name":"StreamFun","kind":"namevalue","type":"function_handle"}
 ],
 "outputs":
 [

@@ -73,7 +83,16 @@
     {"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]},
     {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]},
     {"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"},
-    {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]}
+    {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]},
+    {"name":"APIKey","kind":"namevalue","type":["string","scalar"]},
+    {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
+    {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
+    {"name":"StopSequences","kind":"namevalue","type":["string","vector"]},
+    {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"},
+    {"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
+    {"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
+    {"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
+    {"name":"StreamFun","kind":"namevalue","type":"function_handle"}
 ],
 "outputs":
 [

@@ -90,12 +109,14 @@
     {"name":"systemPrompt","kind":"ordered","type":["string","scalar"]},
     {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
     {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
+    {"name":"MinP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
     {"name":"TopK","kind":"namevalue","type":["numeric","scalar","integer",">=1"]},
     {"name":"StopSequences","kind":"namevalue","type":["string","vector"]},
     {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"},
     {"name":"TailFreeSamplingZ","kind":"namevalue","type":["numeric","scalar","real"]},
     {"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
-    {"name":"StreamFun","kind":"namevalue","type":"function_handle"}
+    {"name":"StreamFun","kind":"namevalue","type":"function_handle"},
+    {"name":"Endpoint","kind":"namevalue","type":["string","scalar"]}
 ],
 "outputs":
 [

@@ -109,7 +130,18 @@
     {"name":"this","kind":"required","type":["ollamaChat","scalar"]},
     {"name":"messages","kind":"required","type":[["messageHistory","row"],["string","scalar"]]},
     {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]},
-    {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]}
+    {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]},
+    {"name":"Model","kind":"namevalue","type":"choices=ollamaChat.models"},
+    {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
+    {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
+    {"name":"MinP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
+    {"name":"TopK","kind":"namevalue","type":["numeric","scalar","integer",">=1"]},
+    {"name":"StopSequences","kind":"namevalue","type":["string","vector"]},
+    {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"},
+    {"name":"TailFreeSamplingZ","kind":"namevalue","type":["numeric","scalar","real"]},
+    {"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
+    {"name":"StreamFun","kind":"namevalue","type":"function_handle"},
+    {"name":"Endpoint","kind":"namevalue","type":["string","scalar"]}
 ],
 "outputs":
 [
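These entries mirror the new name-value arguments so that MATLAB's tab completion offers them on generate calls; functionSignatures.json is editor metadata only and has no effect at run time. After editing, the file can be checked with MATLAB's standard validator (run from the folder containing the file; the path below is illustrative):

    % Check the signature file for JSON and schema errors.
    validateFunctionSignaturesJSON("functionSignatures.json")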

ollamaChat.m

Lines changed: 82 additions & 15 deletions

@@ -47,7 +47,6 @@
 %                        value is empty.
 %                        Example: ["The end.", "And that's all she wrote."]
 %
-%
 %   ResponseFormat     - The format of response the model returns.
 %                        "text" (default) | "json"
 %

@@ -128,17 +127,79 @@
 % [TEXT, MESSAGE, RESPONSE] = generate(__, Name=Value) specifies additional options
 % using one or more name-value arguments:
 %
-%   MaxNumTokens       - Maximum number of tokens in the generated response.
-%                        Default value is inf.
+%   MaxNumTokens       - Maximum number of tokens in the generated response.
+%                        Default value is inf.
+%
+%   Seed               - An integer value to use to obtain
+%                        reproducible responses
+%
+%   Model              - Model name (as expected by Ollama server).
+%                        Default value is CHAT.Model.
+%
+%   Temperature        - Temperature value for controlling the randomness
+%                        of the output. Default value is CHAT.Temperature.
+%                        Higher values increase the randomness (in some
+%                        sense, the “creativity”) of outputs, lower
+%                        values reduce it. Setting Temperature=0 removes
+%                        randomness from the output altogether.
+%
+%   TopP               - Top probability mass value for controlling the
+%                        diversity of the output. Default value is CHAT.TopP;
+%                        lower values imply that only the more likely
+%                        words can appear in any particular place.
+%                        This is also known as top-p sampling.
+%
+%   MinP               - Minimum probability ratio for controlling the
+%                        diversity of the output. Default value is CHAT.MinP;
+%                        higher values imply that only the more likely
+%                        words can appear in any particular place.
+%                        This is also known as min-p sampling.
+%
+%   TopK               - Maximum number of most likely tokens that are
+%                        considered for output. Default is CHAT.TopK.
+%                        Smaller values reduce diversity in the output.
+%
+%   TailFreeSamplingZ  - Reduce the use of less probable tokens, based on
+%                        the second-order differences of ordered
+%                        probabilities.
+%                        Default value is CHAT.TailFreeSamplingZ.
+%                        Lower values reduce diversity, with
+%                        some authors recommending values
+%                        around 0.95. Tail-free sampling is
+%                        slower than using TopP or TopK.
+%
+%   StopSequences      - Vector of strings that when encountered, will
+%                        stop the generation of tokens. Default
+%                        value is CHAT.StopSequences.
+%                        Example: ["The end.", "And that's all she wrote."]
+%
+%
+%   ResponseFormat     - The format of response the model returns.
+%                        The default value is CHAT.ResponseFormat.
+%                        "text" (default) | "json"
+%
+%   StreamFun          - Function to callback when streaming the
+%                        result. The default value is CHAT.StreamFun.
+%
+%   TimeOut            - Connection Timeout in seconds. Default is CHAT.TimeOut.
 %
-%   Seed               - An integer value to use to obtain
-%                        reproducible responses

 arguments
     this (1,1) ollamaChat
-    messages {mustBeValidMsgs}
+    messages {mustBeValidMsgs}
+    nvp.Model {mustBeTextScalar} = this.Model
+    nvp.Temperature {llms.utils.mustBeValidTemperature} = this.Temperature
+    nvp.TopP {llms.utils.mustBeValidProbability} = this.TopP
+    nvp.MinP {llms.utils.mustBeValidProbability} = this.MinP
+    nvp.TopK (1,1) {mustBeReal,mustBePositive} = this.TopK
+    nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences
+    nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = this.ResponseFormat
+    nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut
+    nvp.TailFreeSamplingZ (1,1) {mustBeReal} = this.TailFreeSamplingZ
+    nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
+    nvp.Endpoint (1,1) string = this.Endpoint
     nvp.MaxNumTokens (1,1) {mustBePositive} = inf
-    nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
+    nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
 end

@@ -152,15 +213,21 @@
     messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
 end

+if isfield(nvp,"StreamFun")
+    streamFun = nvp.StreamFun;
+else
+    streamFun = this.StreamFun;
+end
+
 [text, message, response] = llms.internal.callOllamaChatAPI(...
-    this.Model, messagesStruct, ...
-    Temperature=this.Temperature, ...
-    TopP=this.TopP, MinP=this.MinP, TopK=this.TopK,...
-    TailFreeSamplingZ=this.TailFreeSamplingZ,...
-    StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
-    ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
-    TimeOut=this.TimeOut, StreamFun=this.StreamFun, ...
-    Endpoint=this.Endpoint);
+    nvp.Model, messagesStruct, ...
+    Temperature=nvp.Temperature, ...
+    TopP=nvp.TopP, MinP=nvp.MinP, TopK=nvp.TopK,...
+    TailFreeSamplingZ=nvp.TailFreeSamplingZ,...
+    StopSequences=nvp.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
+    ResponseFormat=nvp.ResponseFormat,Seed=nvp.Seed, ...
+    TimeOut=nvp.TimeOut, StreamFun=streamFun, ...
+    Endpoint=nvp.Endpoint);

 if isfield(response.Body.Data,"error")
     err = response.Body.Data.error;
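ollamaChat gets the same treatment, plus its Ollama-specific knobs (Model, MinP, TopK, TailFreeSamplingZ, Endpoint). Because Model and Endpoint are now per-call overrides, one chat object can target a different model, or even a different Ollama server, for a single generate call. A minimal sketch; the model name and endpoint below are placeholders, not values from this commit:

    % Assume chat is an existing ollamaChat object.
    txt = generate(chat, "Why is the sky blue?", ...
        Model="mistral", ...                        % placeholder model name
        Endpoint="localhost:11435", ...             % placeholder second server
        StreamFun=@(token) fprintf("%s", token));   % print tokens as they arrive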
