classdef (Sealed) ollamaChat < llms.internal.textGenerator
% ollamaChat Chat completion API from Ollama.
%
% CHAT = ollamaChat(modelName) creates an ollamaChat object for the given model.
%
% CHAT = ollamaChat(__,systemPrompt) creates an ollamaChat object with the
% specified system prompt.
%
% CHAT = ollamaChat(__,Name=Value) specifies additional options
% using one or more name-value arguments:
%
% Temperature        - Temperature value for controlling the randomness
%                      of the output. Default value depends on the model;
%                      if not specified in the model, defaults to 0.8.
%
% TODO: Clarify how Ollama's TopK and TopP options relate to this.
% TopProbabilityMass - Top probability mass value for controlling the
%                      diversity of the output. Default value is 1.
%
% StopSequences      - Vector of strings that, when encountered, stop
%                      the generation of tokens. Default value is empty.
%
% ResponseFormat     - The format of response the model returns.
%                      "text" (default) | "json"
%
% Seed               - Random seed for reproducible output.
%                      TODO: seems to have no effect (tested via curl);
%                      see https://github.com/ollama/ollama/issues/4660
%
% Mirostat           - Mirostat sampling mode: 0 (disabled), 1 (Mirostat),
%                      or 2 (Mirostat 2.0), together with the eta
%                      (learning rate) and tau (target entropy) options.
%
% RepeatLastN        - How far back the model looks to prevent
%                      repetition (default: 64; 0 = disabled,
%                      -1 = num_ctx). TODO: find a better name.
%
% RepeatPenalty      - Penalty applied to repeated tokens; higher
%                      values penalize repetition more strongly.
%
% TailFreeSamplingZ  - Tail-free sampling parameter; 1 disables
%                      the effect.
%
% MaxNumTokens       - Maximum number of tokens to generate.
%
% StreamFun          - Function to call as tokens of the result
%                      are streamed back.
%
% TimeOut            - Connection timeout in seconds (default: 10).
%
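% Example (an illustrative sketch; assumes a local Ollama server is
% running and a model such as "mistral" has already been pulled):
%   chat = ollamaChat("mistral", "You are a helpful assistant.");
%   txt = generate(chat, "Why is the sky blue?");
%
%   % Stream tokens to the Command Window as they are generated:
%   chat = ollamaChat("mistral", StreamFun=@(token) fprintf("%s", token));
%   txt = generate(chat, "Tell me a short story.");
%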
% ollamaChat Functions:
%   ollamaChat - Chat completion API from Ollama.
%   generate   - Generate a response using the ollamaChat instance.
%
% ollamaChat Properties: (TODO: finalize this list)
%   Temperature        - Temperature of generation.
%
%   TopProbabilityMass - Top probability mass to consider for generation.
%
%   StopSequences      - Sequences to stop the generation of tokens.
%
%   PresencePenalty    - Penalty for using a token in the
%                        response that has already been used.
%
%   FrequencyPenalty   - Penalty for using a token that is
%                        frequent in the training data.
%
%   SystemPrompt       - System prompt.
%
%   ResponseFormat     - Specifies the response format, "text" or "json".
%
%   TimeOut            - Connection timeout in seconds (default: 10).
%

% Copyright 2024 The MathWorks, Inc.

    properties (SetAccess=private)
        Model (1,1) string
    end

    methods
        function this = ollamaChat(modelName, systemPrompt, nvp)
            arguments
                modelName {mustBeTextScalar}
                systemPrompt {llms.utils.mustBeTextOrEmpty} = []
                nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
                nvp.TopProbabilityMass {llms.utils.mustBeValidTopP} = 1
                nvp.StopSequences {llms.utils.mustBeValidStop} = {}
                nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
                nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
                nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
            end

            if isfield(nvp,"StreamFun")
                this.StreamFun = nvp.StreamFun;
            else
                this.StreamFun = [];
            end

            if ~isempty(systemPrompt)
                systemPrompt = string(systemPrompt);
                if ~(strlength(systemPrompt)==0)
                    this.SystemPrompt = {struct("role", "system", "content", systemPrompt)};
                end
            end

            this.Model = modelName;
            this.ResponseFormat = nvp.ResponseFormat;
            this.Temperature = nvp.Temperature;
            this.TopProbabilityMass = nvp.TopProbabilityMass;
            this.StopSequences = nvp.StopSequences;
            this.TimeOut = nvp.TimeOut;
        end

        function [text, message, response] = generate(this, messages, nvp)
            % generate Generate a response using the ollamaChat instance.
            %
            % [TEXT, MESSAGE, RESPONSE] = generate(CHAT, MESSAGES) generates a response
            % with the specified MESSAGES.
            %
            % [TEXT, MESSAGE, RESPONSE] = generate(__, Name=Value) specifies additional options
            % using one or more name-value arguments:
            %
            % NumCompletions - Number of completions to generate.
            %                  Default value is 1.
            %
            % MaxNumTokens   - Maximum number of tokens in the generated response.
            %                  Default value is inf.
            %
            % Seed           - An integer value to use to obtain
            %                  reproducible responses.
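            %
            % Example (illustrative; assumes CHAT was created as in the
            % class help above):
            %   [text, message, response] = generate(chat, ...
            %       "Why is the sky blue?", MaxNumTokens=100, Seed=42);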

            arguments
                this (1,1) ollamaChat
                messages (1,1) {mustBeValidMsgs}
                nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1
                nvp.MaxNumTokens (1,1) {mustBePositive} = inf
                nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
            end

            if isstring(messages) && isscalar(messages)
                messagesStruct = {struct("role", "user", "content", messages)};
            else
                messagesStruct = messages.Messages;
            end

            if ~isempty(this.SystemPrompt)
                messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
            end

            [text, message, response] = llms.internal.callOllamaChatAPI(...
                this.Model, messagesStruct, ...
                Temperature=this.Temperature, ...
                NumCompletions=nvp.NumCompletions, ...
                StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
                ResponseFormat=this.ResponseFormat, Seed=nvp.Seed, ...
                TimeOut=this.TimeOut, StreamFun=this.StreamFun);
        end
    end

    methods (Hidden)
        function mustBeValidFunctionCall(this, functionCall)
            if ~isempty(functionCall)
                mustBeTextScalar(functionCall);
                if isempty(this.FunctionNames)
                    error("llms:mustSetFunctionsForCall", llms.utils.errorMessageCatalog.getMessage("llms:mustSetFunctionsForCall"));
                end
                mustBeMember(functionCall, ["none","auto", this.FunctionNames]);
            end
        end

        function toolChoice = convertToolChoice(this, toolChoice)
            % if toolChoice is empty
            if isempty(toolChoice)
                % if Tools is not empty, the default is 'auto'
                if ~isempty(this.Tools)
                    toolChoice = "auto";
                end
            elseif toolChoice ~= "auto"
                % if toolChoice is not empty, then it must be in the format
                % {"type": "function", "function": {"name": "my_function"}}
                toolChoice = struct("type","function","function",struct("name",toolChoice));
            end
        end
    end
end

function mustBeNonzeroLengthTextScalar(content)
mustBeNonzeroLengthText(content)
mustBeTextScalar(content)
end

function [functionsStruct, functionNames] = functionAsStruct(functions)
numFunctions = numel(functions);
functionsStruct = cell(1, numFunctions);
functionNames = strings(1, numFunctions);

for i = 1:numFunctions
    functionsStruct{i} = struct('type','function', ...
        'function',encodeStruct(functions(i)));
    functionNames(i) = functions(i).FunctionName;
end
end

function mustBeValidMsgs(value)
if isa(value, "openAIMessages")
    if numel(value.Messages) == 0
        error("llms:mustHaveMessages", llms.utils.errorMessageCatalog.getMessage("llms:mustHaveMessages"));
    end
else
    try
        llms.utils.mustBeNonzeroLengthTextScalar(value);
    catch
        error("llms:mustBeMessagesOrTxt", llms.utils.errorMessageCatalog.getMessage("llms:mustBeMessagesOrTxt"));
    end
end
end

function mustBeIntegerOrEmpty(value)
if ~isempty(value)
    mustBeInteger(value)
end
end