Commit ccd6961

merge main
2 parents a32e681 + 05ac9a9 · commit ccd6961


51 files changed: +914 / -363 lines

+llms/+internal/callAzureChatAPI.m

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 % More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
 %
 % Example
-%
+%
 % % Create messages struct
 % messages = {struct("role", "system",...
 %     "content", "You are a helpful assistant");

+llms/+internal/callOpenAIChatAPI.m

Lines changed: 27 additions & 11 deletions
@@ -84,8 +84,20 @@
 if isempty(nvp.StreamFun)
     message = response.Body.Data.choices(1).message;
 else
-    message = struct("role", "assistant", ...
-        "content", streamedText);
+    pat = '{"' + wildcardPattern + '":';
+    if contains(streamedText,pat)
+        s = jsondecode(streamedText);
+        if contains(s.function.arguments,pat)
+            prompt = jsondecode(s.function.arguments);
+            s.function.arguments = prompt;
+        end
+        message = struct("role", "assistant", ...
+            "content",[], ...
+            "tool_calls",jsondecode(streamedText));
+    else
+        message = struct("role", "assistant", ...
+            "content", streamedText);
+    end
 end
 if isfield(message, "tool_choice")
     text = "";
@@ -107,18 +119,16 @@
 
 parameters.stream = ~isempty(nvp.StreamFun);
 
-if ~isempty(functions) && ~strcmp(nvp.ModelName,'gpt-4-vision-preview')
+if ~isempty(functions)
     parameters.tools = functions;
 end
 
-if ~isempty(nvp.ToolChoice) && ~strcmp(nvp.ModelName,'gpt-4-vision-preview')
+if ~isempty(nvp.ToolChoice)
     parameters.tool_choice = nvp.ToolChoice;
 end
 
-if ismember(nvp.ModelName,["gpt-3.5-turbo-1106","gpt-4-1106-preview"])
-    if strcmp(nvp.ResponseFormat,"json")
-        parameters.response_format = struct('type','json_object');
-    end
+if strcmp(nvp.ResponseFormat,"json")
+    parameters.response_format = struct('type','json_object');
 end
 
 if ~isempty(nvp.Seed)
@@ -130,15 +140,21 @@
 dict = mapNVPToParameters;
 
 nvpOptions = keys(dict);
-if strcmp(nvp.ModelName,'gpt-4-vision-preview')
-    nvpOptions(ismember(nvpOptions,["MaxNumTokens","StopSequences"])) = [];
-end
 
 for opt = nvpOptions.'
     if isfield(nvp, opt)
         parameters.(dict(opt)) = nvp.(opt);
     end
 end
+
+if isempty(nvp.StopSequences)
+    parameters = rmfield(parameters,"stop");
+end
+
+if nvp.MaxNumTokens == Inf
+    parameters = rmfield(parameters,"max_tokens");
+end
+
 end
 
 function dict = mapNVPToParameters()
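
To see what the new streaming branch handles, here is a hedged, stand-alone sketch (not from the commit) of the kind of accumulated tool-call text it decodes once streaming has finished; the exact JSON shape is an assumption based on the OpenAI streaming format:

    % Hypothetical accumulated stream text for a tool call (invented for illustration)
    streamedText = '{"id":"call_abc123","type":"function","function":{"name":"sayHello","arguments":"{\"name\":\"World\"}"}}';

    pat = '{"' + wildcardPattern + '":';     % same pattern the diff uses to spot JSON-looking text
    if contains(streamedText,pat)
        toolCall = jsondecode(streamedText); % struct with id, type and function fields
        message = struct("role","assistant", "content",[], "tool_calls",toolCall);
    else
        message = struct("role","assistant", "content",streamedText);
    end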

+llms/+internal/textGenerator.m

Lines changed: 14 additions & 61 deletions
@@ -1,22 +1,25 @@
 classdef (Abstract) textGenerator
-
-    properties
+%
+
+% Copyright 2023-2024 The MathWorks, Inc.
+
+    properties
         %TEMPERATURE Temperature of generation.
-        Temperature
+        Temperature {llms.utils.mustBeValidTemperature} = 1
 
         %TOPPROBABILITYMASS Top probability mass to consider for generation.
-        TopProbabilityMass
+        TopProbabilityMass {llms.utils.mustBeValidTopP} = 1
 
         %STOPSEQUENCES Sequences to stop the generation of tokens.
-        StopSequences
+        StopSequences {llms.utils.mustBeValidStop} = {}
 
         %PRESENCEPENALTY Penalty for using a token in the response that has already been used.
-        PresencePenalty
+        PresencePenalty {llms.utils.mustBeValidPenalty} = 0
 
         %FREQUENCYPENALTY Penalty for using a token that is frequent in the training data.
-        FrequencyPenalty
+        FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0
     end
-
+
     properties (SetAccess=protected)
         %TIMEOUT Connection timeout in seconds (default 10 secs)
         TimeOut
@@ -27,64 +30,14 @@
         %SYSTEMPROMPT System prompt.
         SystemPrompt = []
 
-        %RESPONSEFORMAT Response format, "text" or "json"
+        %RESPONSEFORMAT Response format, "text" or "json"
         ResponseFormat
     end
-
+
     properties (Access=protected)
         Tools
         FunctionsStruct
         ApiKey
         StreamFun
     end
-
-
-    methods
-        function this = set.Temperature(this, temperature)
-            arguments
-                this
-                temperature
-            end
-            llms.utils.mustBeValidTemperature(temperature);
-            this.Temperature = temperature;
-        end
-
-        function this = set.TopProbabilityMass(this,topP)
-            arguments
-                this
-                topP
-            end
-            llms.utils.mustBeValidTopP(topP);
-            this.TopProbabilityMass = topP;
-        end
-
-        function this = set.StopSequences(this,stop)
-            arguments
-                this
-                stop
-            end
-            llms.utils.mustBeValidStop(stop);
-            this.StopSequences = stop;
-        end
-
-        function this = set.PresencePenalty(this,penalty)
-            arguments
-                this
-                penalty
-            end
-            llms.utils.mustBeValidPenalty(penalty)
-            this.PresencePenalty = penalty;
-        end
-
-        function this = set.FrequencyPenalty(this,penalty)
-            arguments
-                this
-                penalty
-            end
-            llms.utils.mustBeValidPenalty(penalty)
-            this.FrequencyPenalty = penalty;
-        end
-
-    end
-
-end
+end
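
With validation moved into the property declarations, out-of-range values are now rejected at assignment time without any set methods. A hedged sketch using the openAIChat class from this repository (constructing it assumes a valid OpenAI API key is configured; the exact error text is not shown here):

    chat = openAIChat("You are a helpful assistant");
    chat.Temperature = 1.5;   % accepted: temperature must be a real scalar in [0, 2]
    chat.Temperature = 3;     % rejected by the property validator llms.utils.mustBeValidTemperature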

+llms/+openai/models.m

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+function models = models
+%MODELS - supported OpenAI models
+
+% Copyright 2024 The MathWorks, Inc.
+models = [...
+    "gpt-4o","gpt-4o-2024-05-13",...
+    "gpt-4-turbo","gpt-4-turbo-2024-04-09",...
+    "gpt-4","gpt-4-0613", ...
+    "gpt-3.5-turbo","gpt-3.5-turbo-0125", ...
+    "gpt-3.5-turbo-1106",...
+    ];
+end
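
As a side note (not part of the commit), this function gives a single place to query the supported model list; a minimal sketch:

    supported = llms.openai.models;            % string array of supported model names
    any(supported == "gpt-4o")                 % true
    any(supported == "gpt-4-vision-preview")   % false; this commit also removes its special cases elsewhere
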
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+function validateMessageSupported(message, model);
+%validateMessageSupported - check that message is supported by model
+
+% Copyright 2024 The MathWorks, Inc.
+
+% only certain models support image generation
+if iscell(message.content) && any(cellfun(@(x) isfield(x,"image_url"), message.content))
+    if ~ismember(model,["gpt-4-turbo","gpt-4-turbo-2024-04-09","gpt-4o","gpt-4o-2024-05-13"])
+        error("llms:invalidContentTypeForModel", ...
+            llms.utils.errorMessageCatalog.getMessage("llms:invalidContentTypeForModel", "Image content", model));
+    end
+end
+end
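
For orientation, the message shapes this check distinguishes look roughly like the sketch below; the field layout is an assumption inferred from the iscell/isfield calls above rather than a documented format, and the package-qualified calls assume the function sits alongside models.m:

    % Text-only content: the image check is skipped entirely.
    textMsg = struct("role","user", "content","Hello");

    % Image content: a cell array with an image_url entry (hypothetical layout).
    imageMsg = struct("role","user", "content", ...
        {{struct("type","text","text","Describe this image."), ...
          struct("type","image_url","image_url",struct("url","https://example.com/cat.png"))}});

    % llms.openai.validateMessageSupported(imageMsg, "gpt-4o")        % passes
    % llms.openai.validateMessageSupported(imageMsg, "gpt-3.5-turbo") % errors: llms:invalidContentTypeForModel
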
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+function validateResponseFormat(format,model)
+%validateResponseFormat - validate requested response format is available for selected model
+% Not all OpenAI models support JSON output
+
+% Copyright 2024 The MathWorks, Inc.
+
+if format == "json"
+    if ismember(model,["gpt-4","gpt-4-0613"])
+        error("llms:invalidOptionAndValueForModel", ...
+            llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", model));
+    else
+        warning("llms:warningJsonInstruction", ...
+            llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction"))
+    end
+end
+end
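
A hedged usage sketch (again assuming the function lives in the +llms/+openai package; these calls are not part of the commit):

    llms.openai.validateResponseFormat("json", "gpt-3.5-turbo-0125")  % warns: the prompt must also ask for JSON
    llms.openai.validateResponseFormat("json", "gpt-4")               % errors: JSON mode not supported for this model
    llms.openai.validateResponseFormat("text", "gpt-4")               % nothing to check; returns silently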

+llms/+stream/responseStreamer.m

Lines changed: 34 additions & 5 deletions
@@ -36,14 +36,43 @@
 str = erase(str,"data: ");
 
 for i = 1:length(str)
-    json = jsondecode(str{i});
-    if strcmp(json.choices.finish_reason,'stop')
+    if strcmp(str{i},'[DONE]')
         stop = true;
         return
     else
-        txt = json.choices.delta.content;
-        this.StreamFun(txt);
-        this.ResponseText = [this.ResponseText txt];
+        try
+            json = jsondecode(str{i});
+        catch ME
+            errID = 'llms:stream:responseStreamer:InvalidInput';
+            msg = "Input does not have the expected json format. " + str{i};
+            ME = MException(errID,msg);
+            throw(ME)
+        end
+        if ischar(json.choices.finish_reason) && ismember(json.choices.finish_reason,["stop","tool_calls"])
+            stop = true;
+            return
+        else
+            if isfield(json.choices.delta,"tool_calls")
+                if isfield(json.choices.delta.tool_calls,"id")
+                    id = json.choices.delta.tool_calls.id;
+                    type = json.choices.delta.tool_calls.type;
+                    fcn = json.choices.delta.tool_calls.function;
+                    s = struct('id',id,'type',type,'function',fcn);
+                    txt = jsonencode(s);
+                else
+                    s = jsondecode(this.ResponseText);
+                    args = json.choices.delta.tool_calls.function.arguments;
+                    s.function.arguments = [s.function.arguments args];
+                    txt = jsonencode(s);
+                end
+                this.StreamFun('');
+                this.ResponseText = txt;
+            else
+                txt = json.choices.delta.content;
+                this.StreamFun(txt);
+                this.ResponseText = [this.ResponseText txt];
+            end
+        end
     end
 end
 end
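
To make the accumulation logic concrete, here is a hedged, stand-alone sketch of how two stream chunks would be folded into ResponseText; the chunk JSON is invented for illustration and only approximates the OpenAI streaming format:

    % The first chunk of a tool call carries id/type/function; later chunks only carry argument fragments.
    chunk1 = jsondecode('{"choices":{"delta":{"tool_calls":{"id":"call_1","type":"function","function":{"name":"sayHello","arguments":""}}},"finish_reason":null}}');
    chunk2 = jsondecode('{"choices":{"delta":{"tool_calls":{"function":{"arguments":"{\"name\":\"World\"}"}}},"finish_reason":null}}');

    % chunk1: start a new tool-call record
    s = struct('id',chunk1.choices.delta.tool_calls.id, ...
        'type',chunk1.choices.delta.tool_calls.type, ...
        'function',chunk1.choices.delta.tool_calls.function);
    responseText = jsonencode(s);

    % chunk2: append the argument fragment to the stored record
    s = jsondecode(responseText);
    s.function.arguments = [s.function.arguments chunk2.choices.delta.tool_calls.function.arguments];
    responseText = jsonencode(s);   % '{"id":"call_1","type":"function","function":{"name":"sayHello","arguments":"{\"name\":\"World\"}"}}'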

+llms/+utils/errorMessageCatalog.m

Lines changed: 3 additions & 1 deletion
@@ -49,11 +49,13 @@
 catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects.";
 catalog("llms:invalidOptionAndValueForModel") = "'{1}' with value '{2}' is not supported for ModelName '{3}'";
 catalog("llms:invalidOptionForModel") = "{1} is not supported for ModelName '{2}'";
+catalog("llms:invalidContentTypeForModel") = "{1} is not supported for ModelName '{2}'";
 catalog("llms:functionNotAvailableForModel") = "This function is not supported for ModelName '{1}'";
 catalog("llms:promptLimitCharacter") = "Prompt must have a maximum length of {1} characters for ModelName '{2}'";
 catalog("llms:pngExpected") = "Argument must be a PNG image.";
 catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message.";
 catalog("llms:invalidOptionsForOpenAIBackEnd") = "The parameters Resource Name, Deployment ID and API Version are not compatible with OpenAI.";
 catalog("llms:invalidOptionsForAzureBackEnd") = "The parameter Model Name is not compatible with Azure.";
-
+catalog("llms:apiReturnedError") = "OpenAI API Error: {1}";
+catalog("llms:dimensionsMustBeSmallerThan") = "Dimensions must be less than or equal to {1}.";
 end
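
The numbered holes {1}, {2}, ... in these templates are filled positionally when a message is retrieved, as the getMessage calls in the diffs above show. A small hedged example with the new catalog entry (output text inferred from the template):

    msg = llms.utils.errorMessageCatalog.getMessage("llms:invalidContentTypeForModel", ...
        "Image content", "gpt-3.5-turbo");
    % msg is "Image content is not supported for ModelName 'gpt-3.5-turbo'"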

+llms/+utils/isUnique.m

Lines changed: 0 additions & 1 deletion
@@ -6,4 +6,3 @@
 % Copyright 2023 The MathWorks, Inc.
 tf = numel(values)==numel(unique(values));
 end
-

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,9 @@
 function mustBeNonzeroLengthTextScalar(content)
+% This function is undocumented and will change in a future release
+
+% Simple function to check if value is empty or text scalar
+
+% Copyright 2024 The MathWorks, Inc.
 mustBeNonzeroLengthText(content)
 mustBeTextScalar(content)
-end
+end

+llms/+utils/mustBeTextOrEmpty.m

Lines changed: 1 addition & 1 deletion
@@ -7,4 +7,4 @@ function mustBeTextOrEmpty(value)
 if ~isempty(value)
     mustBeTextScalar(value)
 end
-end
+end

+llms/+utils/mustBeValidPenalty.m

Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,6 @@
 function mustBeValidPenalty(value)
+% This function is undocumented and will change in a future release
+
+% Copyright 2024 The MathWorks, Inc.
 validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonsparse', '<=', 2, '>=', -2})
-end
+end

+llms/+utils/mustBeValidStop.m

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,7 @@
 function mustBeValidStop(value)
+% This function is undocumented and will change in a future release
+
+% Copyright 2024 The MathWorks, Inc.
 if ~isempty(value)
     mustBeVector(value);
     mustBeNonzeroLengthText(value);
@@ -7,4 +10,4 @@ function mustBeValidStop(value)
         error("llms:stopSequencesMustHaveMax4Elements", llms.utils.errorMessageCatalog.getMessage("llms:stopSequencesMustHaveMax4Elements"));
     end
 end
-end
+end

+llms/+utils/mustBeValidTemperature.m

Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,6 @@
 function mustBeValidTemperature(value)
+% This function is undocumented and will change in a future release
+
+% Copyright 2024 The MathWorks, Inc.
 validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 2})
-end
+end

+llms/+utils/mustBeValidTopP.m

Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,6 @@
 function mustBeValidTopP(value)
+% This function is undocumented and will change in a future release
+
+% Copyright 2024 The MathWorks, Inc.
 validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 1})
-end
+end

.github/workflows/ci.yml

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+name: Run MATLAB Tests on GitHub-Hosted Runner
+on: [push]
+jobs:
+  test:
+    name: Run MATLAB Tests and Generate Artifacts
+    runs-on: ubuntu-latest
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+      - name: Set up MATLAB
+        uses: matlab-actions/setup-matlab@v2
+        with:
+          products: Text_Analytics_Toolbox
+          cache: true
+      - name: Run tests and generate artifacts
+        env:
+          OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
+        uses: matlab-actions/run-tests@v2
+        with:
+          test-results-junit: test-results/results.xml
+          code-coverage-cobertura: code-coverage/coverage.xml
+          source-folder: .
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v4
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          slug: matlab-deep-learning/llms-with-matlab

0 commit comments
