Skip to content

Commit fbfa569

Browse files
committed
Added addToolMessage to openAIMessages
AddFunctionMessage was rennamed to addToolMessage, added a new example to the existing MLX and error catalog was also updated.
1 parent adbddc2 commit fbfa569

File tree

6 files changed

+76
-20
lines changed

6 files changed

+76
-20
lines changed

+llms/+utils/errorMessageCatalog.m

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
catalog("llms:parameterMustBeUnique") = "A parameter name equivalent to '{1}' already exists in Parameters. Redefining a parameter is not allowed.";
3939
catalog("llms:mustBeAssistantCall") = "Input struct must contain field 'role' with value 'assistant', and field 'content'.";
4040
catalog("llms:mustBeAssistantWithContent") = "Input struct must contain field 'content' containing text with one or more characters.";
41-
catalog("llms:mustBeAssistantWithNameAndArguments") = "Field 'function_call' must be a struct with fields 'name' and 'arguments'.";
41+
catalog("llms:mustBeAssistantWithIdAndFunction") = "Field 'tool_call' must be a struct with fields 'id' and 'function'.";
42+
catalog("llms:mustBeAssistantWithNameAndArguments") = "Field 'function' must be a struct with fields 'name' and 'arguments'.";
4243
catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters.";
4344
catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1}).";
4445
catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4.";

examples/ExampleChatBot.mlx

415 Bytes
Binary file not shown.
1.88 KB
Binary file not shown.

openAIMessages.m

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -139,28 +139,31 @@
139139

140140
end
141141

142-
function this = addFunctionMessage(this, name, content)
143-
%addFunctionMessage Add function message.
142+
function this = addToolMessage(this, id, name, content)
143+
%addToolMessage Add Tool message.
144144
%
145-
% MESSAGES = addFunctionMessage(MESSES, NAME, CONTENT) adds a function
146-
% message with the specified name and content. NAME and
147-
% CONTENT must be text scalars.
145+
% MESSAGES = addFunctionMessage(MESSAGES, ID, NAME, CONTENT)
146+
% adds a tool message with the specified id, name and content.
147+
% ID, NAME and CONTENT must be text scalars.
148148
%
149149
% Example:
150150
% % Create messages object
151151
% messages = openAIMessages;
152152
%
153153
% % Add function message, containing the result of
154154
% % calling strcat("Hello", " World")
155-
% messages = addFunctionMessage(messages, "strcat", "Hello World");
155+
% messages = addToolMessage(messages, "call_123", "strcat", "Hello World");
156156

157157
arguments
158158
this (1,1) openAIMessages
159+
id {mustBeNonzeroLengthTextScalar}
159160
name {mustBeNonzeroLengthTextScalar}
160161
content {mustBeNonzeroLengthTextScalar}
162+
161163
end
162164

163-
newMessage = struct("role", "function", "name", string(name), "content", string(content));
165+
newMessage = struct("tool_call_id", id, "role", "tool", ...
166+
"name", string(name), "content", string(content));
164167
this.Messages{end+1} = newMessage;
165168
end
166169

@@ -201,9 +204,9 @@
201204

202205
% Assistant is asking for function call
203206
if isfield(messageStruct, "tool_calls")
204-
toolCall = messageStruct.tool_calls{1};
205-
validateAssistantWithFunctionCall(toolCall.function)
206-
this = addAssistantMessage(this, messageStruct.content, toolCall);
207+
toolCalls = messageStruct.tool_calls;
208+
validateAssistantWithToolCalls(toolCalls)
209+
this = addAssistantMessage(this, messageStruct.content, toolCalls);
207210
else
208211
% Simple assistant response
209212
validateRegularAssistant(messageStruct.content);
@@ -254,10 +257,19 @@
254257
newMessage = struct("role", "assistant", "content", content);
255258
else
256259
% tool_calls message
257-
functionCall = struct("name", toolCalls.function.name, "arguments", toolCalls.function.arguments);
258-
toolsStruct = struct("id", toolCalls.id, "type", toolCalls.type, "function", functionCall);
260+
toolsStruct = repmat(struct("id",[],"type",[],"function",[]),size(toolCalls));
261+
for i = 1:numel(toolCalls)
262+
toolsStruct(i).id = toolCalls(i).id;
263+
toolsStruct(i).type = toolCalls(i).type;
264+
toolsStruct(i).function = struct( ...
265+
"name", toolCalls(i).function.name, ...
266+
"arguments", toolCalls(i).function.arguments);
267+
end
268+
259269
newMessage = struct("role", "assistant", "content", content, "tool_calls", toolsStruct);
260-
newMessage.tool_calls = {newMessage.tool_calls};
270+
if numel(newMessage.tool_calls) == 1
271+
newMessage.tool_calls = {newMessage.tool_calls};
272+
end
261273
end
262274

263275
if isempty(this.Messages)
@@ -283,17 +295,26 @@ function validateRegularAssistant(content)
283295
end
284296
end
285297

286-
function validateAssistantWithFunctionCall(functionCallStruct)
287-
if ~isstruct(functionCallStruct)||~isfield(functionCallStruct, "name")||~isfield(functionCallStruct, "arguments")
298+
function validateAssistantWithToolCalls(toolCallStruct)
299+
if ~isstruct(toolCallStruct)||~isfield(toolCallStruct, "id")||~isfield(toolCallStruct, "function")
300+
error("llms:mustBeAssistantWithIdAndFunction", ...
301+
llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithIdAndFunction"))
302+
else
303+
functionCallStruct = [toolCallStruct.function];
304+
end
305+
306+
if ~isfield(functionCallStruct, "name")||~isfield(functionCallStruct, "arguments")
288307
error("llms:mustBeAssistantWithNameAndArguments", ...
289308
llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithNameAndArguments"))
290309
end
291310

292311
try
293-
mustBeNonzeroLengthText(functionCallStruct.name)
294-
mustBeTextScalar(functionCallStruct.name)
295-
mustBeNonzeroLengthText(functionCallStruct.arguments)
296-
mustBeTextScalar(functionCallStruct.arguments)
312+
for i = 1:numel(functionCallStruct)
313+
mustBeNonzeroLengthText(functionCallStruct(i).name)
314+
mustBeTextScalar(functionCallStruct(i).name)
315+
mustBeNonzeroLengthText(functionCallStruct(i).arguments)
316+
mustBeTextScalar(functionCallStruct(i).arguments)
317+
end
297318
catch ME
298319
error("llms:assistantMustHaveTextNameAndArguments", ...
299320
llms.utils.errorMessageCatalog.getMessage("llms:assistantMustHaveTextNameAndArguments"))

0 commit comments

Comments
 (0)