@@ -24,7 +24,7 @@ function differentInputTextAccepted(testCase, ValidTextInput)
24
24
testCase .verifyWarningFree(@()addSystemMessage(msgs , ValidTextInput , ValidTextInput ));
25
25
testCase .verifyWarningFree(@()addSystemMessage(msgs , ValidTextInput , ValidTextInput ));
26
26
testCase .verifyWarningFree(@()addUserMessage(msgs , ValidTextInput ));
27
- testCase .verifyWarningFree(@()addFunctionMessage (msgs , ValidTextInput , ValidTextInput ));
27
+ testCase .verifyWarningFree(@()addToolMessage (msgs , ValidTextInput , ValidTextInput , ValidTextInput ));
28
28
end
29
29
30
30
@@ -59,12 +59,13 @@ function userImageMessageIsAddedWithRemoteImg(testCase)
59
59
testCase .verifyWarningFree(@()addUserMessageWithImages(msgs , prompt , img ));
60
60
end
61
61
62
- function functionMessageIsAdded (testCase )
62
+ function toolMessageIsAdded (testCase )
63
63
prompt = " 20" ;
64
64
name = " sin" ;
65
+ id = " 123" ;
65
66
msgs = openAIMessages ;
66
- systemPrompt = struct(" role" , " function " , " name" , name , " content" , prompt );
67
- msgs = addFunctionMessage (msgs , name , prompt );
67
+ systemPrompt = struct(" tool_call_id " , id , " role" , " tool " , " name" , name , " content" , prompt );
68
+ msgs = addToolMessage (msgs , id , name , prompt );
68
69
testCase .verifyEqual(msgs.Messages{1 }, systemPrompt );
69
70
end
70
71
@@ -76,27 +77,39 @@ function assistantMessageIsAdded(testCase)
76
77
testCase .verifyEqual(msgs.Messages{1 }, assistantPrompt );
77
78
end
78
79
79
- function assistantFunctionCallMessageIsAdded (testCase )
80
+ function assistantToolCallMessageIsAdded (testCase )
80
81
msgs = openAIMessages ;
81
82
functionName = " functionName" ;
82
83
args = " {"" arg1"" : 1, "" arg2"" : 2, "" arg3"" : "" 3"" }" ;
83
84
funCall = struct(" name" , functionName , " arguments" , args );
84
85
toolCall = struct(" id" , " 123" , " type" , " function" , " function" , funCall );
85
- functionCallPrompt = struct(" role" , " assistant" , " content" , " " ," tool_calls" , toolCall );
86
- functionCallPrompt .tool_calls = {functionCallPrompt . tool_calls };
87
- msgs = addResponseMessage(msgs , functionCallPrompt );
88
- testCase .verifyEqual(msgs.Messages{1 }, functionCallPrompt );
86
+ toolCallPrompt = struct(" role" , " assistant" , " content" , " " , " tool_calls" , [] );
87
+ toolCallPrompt .tool_calls = {toolCall };
88
+ msgs = addResponseMessage(msgs , toolCallPrompt );
89
+ testCase .verifyEqual(msgs.Messages{1 }, toolCallPrompt );
89
90
end
90
91
91
- function assistantFunctionCallMessageWithoutArgsIsAdded (testCase )
92
+ function assistantToolCallMessageWithoutArgsIsAdded (testCase )
92
93
msgs = openAIMessages ;
93
94
functionName = " functionName" ;
94
95
funCall = struct(" name" , functionName , " arguments" , " {}" );
95
96
toolCall = struct(" id" , " 123" , " type" , " function" , " function" , funCall );
96
- functionCallPrompt = struct(" role" , " assistant" , " content" , " " ," tool_calls" , toolCall );
97
- functionCallPrompt.tool_calls = {functionCallPrompt .tool_calls };
98
- msgs = addResponseMessage(msgs , functionCallPrompt );
99
- testCase .verifyEqual(msgs.Messages{1 }, functionCallPrompt );
97
+ toolCallPrompt = struct(" role" , " assistant" , " content" , " " ," tool_calls" , []);
98
+ toolCallPrompt.tool_calls = {toolCall };
99
+ msgs = addResponseMessage(msgs , toolCallPrompt );
100
+ testCase .verifyEqual(msgs.Messages{1 }, toolCallPrompt );
101
+ end
102
+
103
+ function assistantParallelToolCallMessageIsAdded(testCase )
104
+ msgs = openAIMessages ;
105
+ functionName = " functionName" ;
106
+ args = " {"" arg1"" : 1, "" arg2"" : 2, "" arg3"" : "" 3"" }" ;
107
+ funCall = struct(" name" , functionName , " arguments" , args );
108
+ toolCall = struct(" id" , " 123" , " type" , " function" , " function" , funCall );
109
+ toolCallPrompt = struct(" role" , " assistant" , " content" , " " , " tool_calls" , []);
110
+ toolCallPrompt.tool_calls = [toolCall ,toolCall ,toolCall ];
111
+ msgs = addResponseMessage(msgs , toolCallPrompt );
112
+ testCase .verifyEqual(msgs.Messages{1 }, toolCallPrompt );
100
113
end
101
114
102
115
function messageGetsRemoved(testCase )
@@ -105,7 +118,7 @@ function messageGetsRemoved(testCase)
105
118
106
119
msgs = addSystemMessage(msgs , " name" , " content" );
107
120
msgs = addUserMessage(msgs , " content" );
108
- msgs = addFunctionMessage (msgs , " name" , " content" );
121
+ msgs = addToolMessage (msgs , " 123 " , " name" , " content" );
109
122
sizeMsgs = length(msgs .Messages );
110
123
% Message exists before removal
111
124
msgToBeRemoved = msgs.Messages{idx };
@@ -121,7 +134,7 @@ function removalIdxCantBeLargerThanNumElements(testCase)
121
134
122
135
msgs = addSystemMessage(msgs , " name" , " content" );
123
136
msgs = addUserMessage(msgs , " content" );
124
- msgs = addFunctionMessage (msgs , " name" , " content" );
137
+ msgs = addToolMessage (msgs , " 123 " , " name" , " content" );
125
138
sizeMsgs = length(msgs .Messages );
126
139
127
140
testCase .verifyError(@()removeMessage(msgs , sizeMsgs + 1 ), " llms:mustBeValidIndex" );
@@ -144,7 +157,7 @@ function invalidInputsUserImagesPrompt(testCase, InvalidInputsUserImagesPrompt)
144
157
145
158
function invalidInputsFunctionPrompt(testCase , InvalidInputsFunctionPrompt )
146
159
msgs = openAIMessages ;
147
- testCase .verifyError(@()addFunctionMessage (msgs ,InvalidInputsFunctionPrompt.Input{: }), InvalidInputsFunctionPrompt .Error );
160
+ testCase .verifyError(@()addToolMessage (msgs ,InvalidInputsFunctionPrompt.Input{: }), InvalidInputsFunctionPrompt .Error );
148
161
end
149
162
150
163
function invalidInputsRemove(testCase , InvalidRemoveMessage )
@@ -231,27 +244,27 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage)
231
244
function invalidFunctionPrompt = iGetInvalidFunctionPrompt
232
245
invalidFunctionPrompt = struct( ...
233
246
" NonStringInputName" , ...
234
- struct(" Input" , {{123 , " content" }}, ...
247
+ struct(" Input" , {{" 123 " , 123 , " content" }}, ...
235
248
" Error" , " MATLAB:validators:mustBeNonzeroLengthText" ), ...
236
249
...
237
250
" NonStringInputContent" , ...
238
- struct(" Input" , {{" name" , 123 }}, ...
251
+ struct(" Input" , {{" 123 " , " name" , 123 }}, ...
239
252
" Error" , " MATLAB:validators:mustBeNonzeroLengthText" ), ...
240
253
...
241
254
" EmptytName" , ...
242
- struct(" Input" , {{" " , " content" }}, ...
255
+ struct(" Input" , {{" 123 " , " " , " content" }}, ...
243
256
" Error" , " MATLAB:validators:mustBeNonzeroLengthText" ), ...
244
257
...
245
258
" EmptytContent" , ...
246
- struct(" Input" , {{" name" , " " }}, ...
259
+ struct(" Input" , {{" 123 " , " name" , " " }}, ...
247
260
" Error" , " MATLAB:validators:mustBeNonzeroLengthText" ), ...
248
261
...
249
262
" NonScalarInputName" , ...
250
- struct(" Input" , {{[" name1" " name2" ], " content" }}, ...
263
+ struct(" Input" , {{" 123 " , [" name1" " name2" ], " content" }}, ...
251
264
" Error" , " MATLAB:validators:mustBeTextScalar" ),...
252
265
...
253
266
" NonScalarInputContent" , ...
254
- struct(" Input" , {{" name" , [" content1" , " content2" ]}}, ...
267
+ struct(" Input" , {{" 123 " , " name" , [" content1" , " content2" ]}}, ...
255
268
" Error" , " MATLAB:validators:mustBeTextScalar" ));
256
269
end
257
270
0 commit comments