
Commit 626ddc6

Move model capability verification out of openAIChat.m, for maintainability
1 parent f4f97b9 commit 626ddc6

File tree

+llms/+openai/models.m
+llms/+openai/validateMessageSupported.m
+llms/+openai/validateResponseFormat.m
openAIChat.m
tests/topenAIChat.m

5 files changed: +217 -24 lines changed


+llms/+openai/models.m

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+function models = models
+%MODELS - supported OpenAI models
+
+% Copyright 2024 The MathWorks, Inc.
+models = [...
+    "gpt-4o","gpt-4o-2024-05-13",...
+    "gpt-4-turbo","gpt-4-turbo-2024-04-09",...
+    "gpt-4","gpt-4-0613", ...
+    "gpt-3.5-turbo","gpt-3.5-turbo-0125", ...
+    "gpt-3.5-turbo-1106",...
+    ];
+end
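For reference, a minimal usage sketch (not part of the commit) of how this centralized list backs validation, in the same way the mustBeModel helper added to openAIChat.m below uses it:

    % Query the list of supported model names.
    supported = llms.openai.models;
    % Validate a candidate name against that list; mustBeMember errors for anything else.
    mustBeMember("gpt-4o", llms.openai.models)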
+llms/+openai/validateMessageSupported.m

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+function validateMessageSupported(message, model)
+%validateMessageSupported - check that message is supported by model
+
+% Copyright 2024 The MathWorks, Inc.
+
+% only certain models support image inputs
+if iscell(message.content) && any(cellfun(@(x) isfield(x,"image_url"), message.content))
+    if ~ismember(model,["gpt-4-turbo","gpt-4-turbo-2024-04-09","gpt-4o","gpt-4o-2024-05-13"])
+        error("llms:invalidContentTypeForModel", ...
+            llms.utils.errorMessageCatalog.getMessage("llms:invalidContentTypeForModel", "Image content", model));
+    end
+end
+end
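As a quick illustration (the message values are hypothetical, not from the commit), the validator expects an OpenAI-style message whose content is a cell array of content parts:

    % Hypothetical user message carrying an image_url content part.
    msg = struct("role","user", ...
        "content",{{struct("type","image_url","image_url",struct("url","https://example.com/image.png"))}});
    llms.openai.validateMessageSupported(msg, "gpt-4o")        % passes: model accepts image input
    llms.openai.validateMessageSupported(msg, "gpt-3.5-turbo") % errors: llms:invalidContentTypeForModel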
+llms/+openai/validateResponseFormat.m

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+function validateResponseFormat(format,model)
+%validateResponseFormat - validate requested response format is available for selected model
+% Not all OpenAI models support JSON output
+
+% Copyright 2024 The MathWorks, Inc.
+
+if format == "json"
+    if ismember(model,["gpt-4","gpt-4-0613"])
+        error("llms:invalidOptionAndValueForModel", ...
+            llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", model));
+    else
+        warning("llms:warningJsonInstruction", ...
+            llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction"))
+    end
+end
+end
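A brief sketch of the two branches (model names taken from the code above; the calls themselves are illustrative):

    llms.openai.validateResponseFormat("json", "gpt-4")          % errors: llms:invalidOptionAndValueForModel
    llms.openai.validateResponseFormat("json", "gpt-3.5-turbo")  % warns: llms:warningJsonInstruction
    llms.openai.validateResponseFormat("text", "gpt-3.5-turbo")  % no-op for the default text format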

openAIChat.m

Lines changed: 8 additions & 23 deletions
@@ -114,13 +114,7 @@
         arguments
             systemPrompt {llms.utils.mustBeTextOrEmpty} = []
             nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
-            nvp.ModelName (1,1) string {mustBeMember(nvp.ModelName,[...
-                "gpt-4o","gpt-4o-2024-05-13",...
-                "gpt-4-turbo","gpt-4-turbo-2024-04-09",...
-                "gpt-4","gpt-4-0613", ...
-                "gpt-3.5-turbo","gpt-3.5-turbo-0125", ...
-                "gpt-3.5-turbo-1106",...
-                ])} = "gpt-3.5-turbo"
+            nvp.ModelName (1,1) string {mustBeModel} = "gpt-3.5-turbo"
             nvp.Temperature {mustBeValidTemperature} = 1
             nvp.TopProbabilityMass {mustBeValidTopP} = 1
             nvp.StopSequences {mustBeValidStop} = {}
@@ -160,16 +154,8 @@
             this.StopSequences = nvp.StopSequences;

             % ResponseFormat is only supported in the latest models only
-            if nvp.ResponseFormat == "json"
-                if ismember(this.ModelName,["gpt-4","gpt-4-0613"])
-                    error("llms:invalidOptionAndValueForModel", ...
-                        llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", this.ModelName));
-                else
-                    warning("llms:warningJsonInstruction", ...
-                        llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction"))
-                end
-
-            end
+            llms.openai.validateResponseFormat(nvp.ResponseFormat, this.ModelName);
+            this.ResponseFormat = nvp.ResponseFormat;

             this.PresencePenalty = nvp.PresencePenalty;
             this.FrequencyPenalty = nvp.FrequencyPenalty;
@@ -219,12 +205,7 @@
                 messagesStruct = messages.Messages;
             end

-            if iscell(messagesStruct{end}.content) && any(cellfun(@(x) isfield(x,"image_url"), messagesStruct{end}.content))
-                if ~ismember(this.ModelName,["gpt-4-turbo","gpt-4-turbo-2024-04-09","gpt-4o","gpt-4o-2024-05-13"])
-                    error("llms:invalidContentTypeForModel", ...
-                        llms.utils.errorMessageCatalog.getMessage("llms:invalidContentTypeForModel", "Image content", this.ModelName));
-                end
-            end
+            llms.openai.validateMessageSupported(messagesStruct{end}, model);

             if ~isempty(this.SystemPrompt)
                 messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
@@ -334,3 +315,7 @@ function mustBeIntegerOrEmpty(value)
     mustBeInteger(value)
 end
 end
+
+function mustBeModel(model)
+    mustBeMember(model,llms.openai.models);
+end
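Caller-facing behavior is intended to stay the same after the refactoring; a sketch of the two validation paths as seen from the constructor (placeholder key, illustrative only):

    % Unsupported model names are now rejected by the mustBeModel arguments validator.
    chat = openAIChat(ApiKey="not-a-real-key", ModelName="some-unsupported-model");
    % JSON output on a capable model still triggers the llms:warningJsonInstruction warning,
    % now raised from llms.openai.validateResponseFormat.
    chat = openAIChat(ApiKey="not-a-real-key", ResponseFormat="json");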

tests/topenAIChat.m

Lines changed: 168 additions & 1 deletion
@@ -16,6 +16,7 @@ function saveEnvVar(testCase)
     end

     properties(TestParameter)
+        ValidConstructorInput = iGetValidConstructorInput();
         InvalidConstructorInput = iGetInvalidConstructorInput();
         InvalidGenerateInput = iGetInvalidGenerateInput();
         InvalidValuesSetters = iGetInvalidValuesSetters();
@@ -65,6 +66,21 @@ function constructChatWithAllNVP(testCase)
             testCase.verifyEqual(chat.PresencePenalty, presenceP);
         end

+        function validConstructorCalls(testCase,ValidConstructorInput)
+            if isempty(ValidConstructorInput.ExpectedWarning)
+                chat = testCase.verifyWarningFree(...
+                    @() openAIChat(ValidConstructorInput.Input{:}));
+            else
+                chat = testCase.verifyWarning(...
+                    @() openAIChat(ValidConstructorInput.Input{:}), ...
+                    ValidConstructorInput.ExpectedWarning);
+            end
+            properties = ValidConstructorInput.VerifyProperties;
+            for prop=string(fieldnames(properties)).'
+                testCase.verifyEqual(chat.(prop),properties.(prop),"Property " + prop);
+            end
+        end
+
         function verySmallTimeOutErrors(testCase)
             chat = openAIChat(TimeOut=0.0001, ApiKey="false-key");

@@ -126,7 +142,6 @@ function noStopSequencesNoMaxNumTokens(testCase)
         end

         function createOpenAIChatWithStreamFunc(testCase)
-
             function seen = sf(str)
                 persistent data;
                 if isempty(data)
@@ -275,6 +290,158 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase)
        "Error", "MATLAB:notGreaterEqual"));
end

+function validConstructorInput = iGetValidConstructorInput()
+% while it is valid to provide the key via an environment variable,
+% this test set does not use that, for easier setup
+validFunction = openAIFunction("funName");
+validConstructorInput = struct( ...
+    "JustKey", struct( ...
+        "Input",{{"ApiKey","this-is-not-a-real-key"}}, ...
+        "ExpectedWarning", '', ...
+        "VerifyProperties", struct( ...
+            "Temperature", {1}, ...
+            "TopProbabilityMass", {1}, ...
+            "StopSequences", {{}}, ...
+            "PresencePenalty", {0}, ...
+            "FrequencyPenalty", {0}, ...
+            "TimeOut", {10}, ...
+            "FunctionNames", {[]}, ...
+            "ModelName", {"gpt-3.5-turbo"}, ...
+            "SystemPrompt", {[]}, ...
+            "ResponseFormat", {"text"} ...
+        ) ...
+    ), ...
+    "SystemPrompt", struct( ...
+        "Input",{{"system prompt","ApiKey","this-is-not-a-real-key"}}, ...
+        "ExpectedWarning", '', ...
+        "VerifyProperties", struct( ...
+            "Temperature", {1}, ...
+            "TopProbabilityMass", {1}, ...
+            "StopSequences", {{}}, ...
+            "PresencePenalty", {0}, ...
+            "FrequencyPenalty", {0}, ...
+            "TimeOut", {10}, ...
+            "FunctionNames", {[]}, ...
+            "ModelName", {"gpt-3.5-turbo"}, ...
+            "SystemPrompt", {{struct("role","system","content","system prompt")}}, ...
+            "ResponseFormat", {"text"} ...
+        ) ...
+    ), ...
+    "Temperature", struct( ...
+        "Input",{{"ApiKey","this-is-not-a-real-key","Temperature",2}}, ...
+        "ExpectedWarning", '', ...
+        "VerifyProperties", struct( ...
+            "Temperature", {2}, ...
+            "TopProbabilityMass", {1}, ...
+            "StopSequences", {{}}, ...
+            "PresencePenalty", {0}, ...
+            "FrequencyPenalty", {0}, ...
+            "TimeOut", {10}, ...
+            "FunctionNames", {[]}, ...
+            "ModelName", {"gpt-3.5-turbo"}, ...
+            "SystemPrompt", {[]}, ...
+            "ResponseFormat", {"text"} ...
+        ) ...
+    ), ...
+    "TopProbabilityMass", struct( ...
+        "Input",{{"ApiKey","this-is-not-a-real-key","TopProbabilityMass",0.2}}, ...
+        "ExpectedWarning", '', ...
+        "VerifyProperties", struct( ...
+            "Temperature", {1}, ...
+            "TopProbabilityMass", {0.2}, ...
+            "StopSequences", {{}}, ...
+            "PresencePenalty", {0}, ...
+            "FrequencyPenalty", {0}, ...
+            "TimeOut", {10}, ...
+            "FunctionNames", {[]}, ...
+            "ModelName", {"gpt-3.5-turbo"}, ...
+            "SystemPrompt", {[]}, ...
+            "ResponseFormat", {"text"} ...
+        ) ...
+    ), ...
+    "StopSequences", struct( ...
+        "Input",{{"ApiKey","this-is-not-a-real-key","StopSequences",["foo","bar"]}}, ...
+        "ExpectedWarning", '', ...
+        "VerifyProperties", struct( ...
+            "Temperature", {1}, ...
+            "TopProbabilityMass", {1}, ...
+            "StopSequences", {["foo","bar"]}, ...
+            "PresencePenalty", {0}, ...
+            "FrequencyPenalty", {0}, ...
+            "TimeOut", {10}, ...
+            "FunctionNames", {[]}, ...
+            "ModelName", {"gpt-3.5-turbo"}, ...
+            "SystemPrompt", {[]}, ...
+            "ResponseFormat", {"text"} ...
+        ) ...
+    ), ...
+    "PresencePenalty", struct( ...
+        "Input",{{"ApiKey","this-is-not-a-real-key","PresencePenalty",0.1}}, ...
+        "ExpectedWarning", '', ...
+        "VerifyProperties", struct( ...
+            "Temperature", {1}, ...
+            "TopProbabilityMass", {1}, ...
+            "StopSequences", {{}}, ...
+            "PresencePenalty", {0.1}, ...
+            "FrequencyPenalty", {0}, ...
+            "TimeOut", {10}, ...
+            "FunctionNames", {[]}, ...
+            "ModelName", {"gpt-3.5-turbo"}, ...
+            "SystemPrompt", {[]}, ...
+            "ResponseFormat", {"text"} ...
+        ) ...
+    ), ...
+    "FrequencyPenalty", struct( ...
+        "Input",{{"ApiKey","this-is-not-a-real-key","FrequencyPenalty",0.1}}, ...
+        "ExpectedWarning", '', ...
+        "VerifyProperties", struct( ...
+            "Temperature", {1}, ...
+            "TopProbabilityMass", {1}, ...
+            "StopSequences", {{}}, ...
+            "PresencePenalty", {0}, ...
+            "FrequencyPenalty", {0.1}, ...
+            "TimeOut", {10}, ...
+            "FunctionNames", {[]}, ...
+            "ModelName", {"gpt-3.5-turbo"}, ...
+            "SystemPrompt", {[]}, ...
+            "ResponseFormat", {"text"} ...
+        ) ...
+    ), ...
+    "TimeOut", struct( ...
+        "Input",{{"ApiKey","this-is-not-a-real-key","TimeOut",0.1}}, ...
+        "ExpectedWarning", '', ...
+        "VerifyProperties", struct( ...
+            "Temperature", {1}, ...
+            "TopProbabilityMass", {1}, ...
+            "StopSequences", {{}}, ...
+            "PresencePenalty", {0}, ...
+            "FrequencyPenalty", {0}, ...
+            "TimeOut", {0.1}, ...
+            "FunctionNames", {[]}, ...
+            "ModelName", {"gpt-3.5-turbo"}, ...
+            "SystemPrompt", {[]}, ...
+            "ResponseFormat", {"text"} ...
+        ) ...
+    ), ...
+    "ResponseFormat", struct( ...
+        "Input",{{"ApiKey","this-is-not-a-real-key","ResponseFormat","json"}}, ...
+        "ExpectedWarning", "llms:warningJsonInstruction", ...
+        "VerifyProperties", struct( ...
+            "Temperature", {1}, ...
+            "TopProbabilityMass", {1}, ...
+            "StopSequences", {{}}, ...
+            "PresencePenalty", {0}, ...
+            "FrequencyPenalty", {0}, ...
+            "TimeOut", {10}, ...
+            "FunctionNames", {[]}, ...
+            "ModelName", {"gpt-3.5-turbo"}, ...
+            "SystemPrompt", {[]}, ...
+            "ResponseFormat", {"json"} ...
+        ) ...
+    ) ...
+    );
+end
+
function invalidConstructorInput = iGetInvalidConstructorInput()
validFunction = openAIFunction("funName");
invalidConstructorInput = struct( ...
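The new ValidConstructorInput test parameter drives validConstructorCalls above; as usual for MATLAB unit tests, the suite can be run with runtests (illustrative, not part of the commit):

    % Run all tests in the tests folder, including topenAIChat.m.
    results = runtests("tests");
    table(results)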
