Skip to content

Commit 3bf8430

Browse files
committed
minor bugfixes - but the test still fails
1 parent ec0d493 commit 3bf8430

File tree

3 files changed

+41
-27
lines changed

3 files changed

+41
-27
lines changed
245 Bytes
Binary file not shown.

openAIMessages.m

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,10 +265,11 @@
265265
"name", toolCalls(i).function.name, ...
266266
"arguments", toolCalls(i).function.arguments);
267267
end
268-
269-
newMessage = struct("role", "assistant", "content", content, "tool_calls", toolsStruct);
270-
if numel(newMessage.tool_calls) == 1
271-
newMessage.tool_calls = {newMessage.tool_calls};
268+
if numel(toolsStruct) > 1
269+
newMessage = struct("role", "assistant", "content", content, "tool_calls", toolsStruct);
270+
else
271+
newMessage = struct("role", "assistant", "content", content, "tool_calls", []);
272+
newMessage.tool_calls = {toolsStruct};
272273
end
273274
end
274275

tests/topenAIMessages.m

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function differentInputTextAccepted(testCase, ValidTextInput)
2424
testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput));
2525
testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput));
2626
testCase.verifyWarningFree(@()addUserMessage(msgs, ValidTextInput));
27-
testCase.verifyWarningFree(@()addFunctionMessage(msgs, ValidTextInput, ValidTextInput));
27+
testCase.verifyWarningFree(@()addToolMessage(msgs, ValidTextInput, ValidTextInput, ValidTextInput));
2828
end
2929

3030

@@ -59,12 +59,13 @@ function userImageMessageIsAddedWithRemoteImg(testCase)
5959
testCase.verifyWarningFree(@()addUserMessageWithImages(msgs, prompt, img));
6060
end
6161

62-
function functionMessageIsAdded(testCase)
62+
function toolMessageIsAdded(testCase)
6363
prompt = "20";
6464
name = "sin";
65+
id = "123";
6566
msgs = openAIMessages;
66-
systemPrompt = struct("role", "function", "name", name, "content", prompt);
67-
msgs = addFunctionMessage(msgs, name, prompt);
67+
systemPrompt = struct("tool_call_id", id, "role", "tool", "name", name, "content", prompt);
68+
msgs = addToolMessage(msgs, id, name, prompt);
6869
testCase.verifyEqual(msgs.Messages{1}, systemPrompt);
6970
end
7071

@@ -76,27 +77,39 @@ function assistantMessageIsAdded(testCase)
7677
testCase.verifyEqual(msgs.Messages{1}, assistantPrompt);
7778
end
7879

79-
function assistantFunctionCallMessageIsAdded(testCase)
80+
function assistantToolCallMessageIsAdded(testCase)
8081
msgs = openAIMessages;
8182
functionName = "functionName";
8283
args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}";
8384
funCall = struct("name", functionName, "arguments", args);
8485
toolCall = struct("id", "123", "type", "function", "function", funCall);
85-
functionCallPrompt = struct("role", "assistant", "content", "","tool_calls", toolCall);
86-
functionCallPrompt.tool_calls = {functionCallPrompt.tool_calls};
87-
msgs = addResponseMessage(msgs, functionCallPrompt);
88-
testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt);
86+
toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []);
87+
toolCallPrompt.tool_calls = {toolCall};
88+
msgs = addResponseMessage(msgs, toolCallPrompt);
89+
testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt);
8990
end
9091

91-
function assistantFunctionCallMessageWithoutArgsIsAdded(testCase)
92+
function assistantToolCallMessageWithoutArgsIsAdded(testCase)
9293
msgs = openAIMessages;
9394
functionName = "functionName";
9495
funCall = struct("name", functionName, "arguments", "{}");
9596
toolCall = struct("id", "123", "type", "function", "function", funCall);
96-
functionCallPrompt = struct("role", "assistant", "content", "","tool_calls", toolCall);
97-
functionCallPrompt.tool_calls = {functionCallPrompt.tool_calls};
98-
msgs = addResponseMessage(msgs, functionCallPrompt);
99-
testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt);
97+
toolCallPrompt = struct("role", "assistant", "content", "","tool_calls", []);
98+
toolCallPrompt.tool_calls = {toolCall};
99+
msgs = addResponseMessage(msgs, toolCallPrompt);
100+
testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt);
101+
end
102+
103+
function assistantParallelToolCallMessageIsAdded(testCase)
104+
msgs = openAIMessages;
105+
functionName = "functionName";
106+
args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}";
107+
funCall = struct("name", functionName, "arguments", args);
108+
toolCall = struct("id", "123", "type", "function", "function", funCall);
109+
toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []);
110+
toolCallPrompt.tool_calls = [toolCall,toolCall,toolCall];
111+
msgs = addResponseMessage(msgs, toolCallPrompt);
112+
testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt);
100113
end
101114

102115
function messageGetsRemoved(testCase)
@@ -105,7 +118,7 @@ function messageGetsRemoved(testCase)
105118

106119
msgs = addSystemMessage(msgs, "name", "content");
107120
msgs = addUserMessage(msgs, "content");
108-
msgs = addFunctionMessage(msgs, "name", "content");
121+
msgs = addToolMessage(msgs, "123", "name", "content");
109122
sizeMsgs = length(msgs.Messages);
110123
% Message exists before removal
111124
msgToBeRemoved = msgs.Messages{idx};
@@ -121,7 +134,7 @@ function removalIdxCantBeLargerThanNumElements(testCase)
121134

122135
msgs = addSystemMessage(msgs, "name", "content");
123136
msgs = addUserMessage(msgs, "content");
124-
msgs = addFunctionMessage(msgs, "name", "content");
137+
msgs = addToolMessage(msgs, "123", "name", "content");
125138
sizeMsgs = length(msgs.Messages);
126139

127140
testCase.verifyError(@()removeMessage(msgs, sizeMsgs+1), "llms:mustBeValidIndex");
@@ -144,7 +157,7 @@ function invalidInputsUserImagesPrompt(testCase, InvalidInputsUserImagesPrompt)
144157

145158
function invalidInputsFunctionPrompt(testCase, InvalidInputsFunctionPrompt)
146159
msgs = openAIMessages;
147-
testCase.verifyError(@()addFunctionMessage(msgs,InvalidInputsFunctionPrompt.Input{:}), InvalidInputsFunctionPrompt.Error);
160+
testCase.verifyError(@()addToolMessage(msgs,InvalidInputsFunctionPrompt.Input{:}), InvalidInputsFunctionPrompt.Error);
148161
end
149162

150163
function invalidInputsRemove(testCase, InvalidRemoveMessage)
@@ -231,27 +244,27 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage)
231244
function invalidFunctionPrompt = iGetInvalidFunctionPrompt
232245
invalidFunctionPrompt = struct( ...
233246
"NonStringInputName", ...
234-
struct("Input", {{123, "content"}}, ...
247+
struct("Input", {{"123", 123, "content"}}, ...
235248
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
236249
...
237250
"NonStringInputContent", ...
238-
struct("Input", {{"name", 123}}, ...
251+
struct("Input", {{"123", "name", 123}}, ...
239252
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
240253
...
241254
"EmptytName", ...
242-
struct("Input", {{"", "content"}}, ...
255+
struct("Input", {{"123", "", "content"}}, ...
243256
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
244257
...
245258
"EmptytContent", ...
246-
struct("Input", {{"name", ""}}, ...
259+
struct("Input", {{"123", "name", ""}}, ...
247260
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
248261
...
249262
"NonScalarInputName", ...
250-
struct("Input", {{["name1" "name2"], "content"}}, ...
263+
struct("Input", {{"123", ["name1" "name2"], "content"}}, ...
251264
"Error", "MATLAB:validators:mustBeTextScalar"),...
252265
...
253266
"NonScalarInputContent", ...
254-
struct("Input", {{"name", ["content1", "content2"]}}, ...
267+
struct("Input", {{"123","name", ["content1", "content2"]}}, ...
255268
"Error", "MATLAB:validators:mustBeTextScalar"));
256269
end
257270

0 commit comments

Comments
 (0)