Skip to content

Commit a32e681

Browse files
author
Angel Vega Alvarez
committed
Adding support to Azure API
1 parent a038630 commit a32e681

File tree

11 files changed

+919
-109
lines changed

11 files changed

+919
-109
lines changed

+llms/+internal/callAzureChatAPI.m

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
function [text, message, response] = callAzureChatAPI(resourceName, deploymentID, messages, functions, nvp)
%callAzureChatAPI Calls the Azure OpenAI chat completions API.
%
%   MESSAGES and FUNCTIONS should be structs matching the json format
%   required by the OpenAI Chat Completions API.
%   Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
%   Currently, the supported NVP are, including the equivalent name in the API:
%    - ToolChoice (tool_choice)
%    - APIVersion (api-version, Azure-specific)
%    - Temperature (temperature)
%    - TopProbabilityMass (top_p)
%    - NumCompletions (n)
%    - StopSequences (stop)
%    - MaxNumTokens (max_tokens)
%    - PresencePenalty (presence_penalty)
%    - FrequencyPenalty (frequency_penalty)
%    - ResponseFormat (response_format)
%    - Seed (seed)
%    - ApiKey
%    - TimeOut
%    - StreamFun
%   More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
%
%   Example
%
%   % Create messages struct
%   messages = {struct("role", "system",...
%       "content", "You are a helpful assistant");
%       struct("role", "user", ...
%       "content", "What is the edit distance between hi and hello?")};
%
%   % Create functions struct
%   functions = {struct("name", "editDistance", ...
%       "description", "Find edit distance between two strings or documents.", ...
%       "parameters", struct( ...
%       "type", "object", ...
%       "properties", struct(...
%           "str1", struct(...
%               "description", "Source string.", ...
%               "type", "string"),...
%           "str2", struct(...
%               "description", "Target string.", ...
%               "type", "string")),...
%       "required", ["str1", "str2"]))};
%
%   % Define your API key
%   apiKey = "your-api-key-here"
%
%   % Send a request
%   [text, message] = llms.internal.callAzureChatAPI(resourceName, deploymentID, messages, functions, ApiKey=apiKey)

%   Copyright 2023-2024 The MathWorks, Inc.

arguments
    resourceName
    deploymentID
    messages
    functions
    nvp.ToolChoice = []
    nvp.APIVersion = "2023-05-15"
    nvp.Temperature = 1
    nvp.TopProbabilityMass = 1
    nvp.NumCompletions = 1
    nvp.StopSequences = []
    nvp.MaxNumTokens = inf
    nvp.PresencePenalty = 0
    nvp.FrequencyPenalty = 0
    nvp.ResponseFormat = "text"
    nvp.Seed = []
    nvp.ApiKey = ""
    nvp.TimeOut = 10
    nvp.StreamFun = []
end

% Azure routes the request by resource name + deployment; the api-version
% query parameter selects the REST contract version.
END_POINT = "https://" + resourceName + ".openai.azure.com/openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;

parameters = buildParametersCall(messages, functions, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);

% If call errors, "choices" will not be part of response.Body.Data, instead
% we get response.Body.Data.error
if response.StatusCode=="OK"
    % Outputs the first generation
    if isempty(nvp.StreamFun)
        message = response.Body.Data.choices(1).message;
    else
        message = struct("role", "assistant", ...
            "content", streamedText);
    end
    % A tool-call response carries a "tool_calls" field and no text content.
    % (The response field is "tool_calls"; "tool_choice" is a request-only
    % parameter and never appears on the returned message.)
    if isfield(message, "tool_calls")
        text = "";
    else
        text = string(message.content);
    end
else
    text = "";
    message = struct();
end
end
101+
102+
function parameters = buildParametersCall(messages, functions, nvp)
% Builds a struct in the format that is expected by the API, combining
% MESSAGES, FUNCTIONS and parameters in NVP.

parameters = struct();
parameters.messages = messages;

parameters.stream = ~isempty(nvp.StreamFun);

% Only include tool-related fields when they are actually set; sending an
% empty "tools" array or a null "tool_choice" can be rejected by the service.
if ~isempty(functions)
    parameters.tools = functions;
end

if ~isempty(nvp.ToolChoice)
    parameters.tool_choice = nvp.ToolChoice;
end

% JSON mode must be requested explicitly; "text" is the service default,
% so it is omitted from the request.
if strcmp(nvp.ResponseFormat, "json")
    parameters.response_format = struct("type", "json_object");
end

if ~isempty(nvp.Seed)
    parameters.seed = nvp.Seed;
end

% Map the remaining NVP options onto their API field names.
dict = mapNVPToParameters;

nvpOptions = keys(dict);
for opt = nvpOptions.'
    if isfield(nvp, opt)
        parameters.(dict(opt)) = nvp.(opt);
    end
end
end
128+
129+
function dict = mapNVPToParameters()
% Maps this interface's NVP option names onto the field names
% expected by the chat completions API.
nvpNames = ["Temperature"; "TopProbabilityMass"; "NumCompletions"; ...
    "StopSequences"; "MaxNumTokens"; "PresencePenalty"; "FrequencyPenalty"];
apiNames = ["temperature"; "top_p"; "n"; ...
    "stop"; "max_tokens"; "presence_penalty"; "frequency_penalty"];
dict = dictionary(nvpNames, apiNames);
end

+llms/+internal/textGenerator.m

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
classdef (Abstract) textGenerator
%textGenerator Abstract base holding generation options shared by chat interfaces.

    properties
        %TEMPERATURE Temperature of generation.
        Temperature

        %TOPPROBABILITYMASS Top probability mass to consider for generation.
        TopProbabilityMass

        %STOPSEQUENCES Sequences to stop the generation of tokens.
        StopSequences

        %PRESENCEPENALTY Penalty for using a token in the response that has already been used.
        PresencePenalty

        %FREQUENCYPENALTY Penalty for using a token proportional to how often it already appears in the response.
        FrequencyPenalty
    end

    properties (SetAccess=protected)
        %TIMEOUT Connection timeout in seconds (default 10 secs)
        TimeOut

        %FUNCTIONNAMES Names of the functions that the model can request calls
        FunctionNames

        %SYSTEMPROMPT System prompt.
        SystemPrompt = []

        %RESPONSEFORMAT Response format, "text" or "json"
        ResponseFormat
    end

    properties (Access=protected)
        Tools
        FunctionsStruct
        ApiKey
        StreamFun
    end


    methods
        % Each setter delegates validation to llms.utils before storing.
        function this = set.Temperature(this, val)
            arguments
                this
                val
            end
            llms.utils.mustBeValidTemperature(val);
            this.Temperature = val;
        end

        function this = set.TopProbabilityMass(this, val)
            arguments
                this
                val
            end
            llms.utils.mustBeValidTopP(val);
            this.TopProbabilityMass = val;
        end

        function this = set.StopSequences(this, val)
            arguments
                this
                val
            end
            llms.utils.mustBeValidStop(val);
            this.StopSequences = val;
        end

        function this = set.PresencePenalty(this, val)
            arguments
                this
                val
            end
            llms.utils.mustBeValidPenalty(val)
            this.PresencePenalty = val;
        end

        function this = set.FrequencyPenalty(this, val)
            arguments
                this
                val
            end
            llms.utils.mustBeValidPenalty(val)
            this.FrequencyPenalty = val;
        end

    end

end

+llms/+utils/errorMessageCatalog.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,7 @@
5353
catalog("llms:promptLimitCharacter") = "Prompt must have a maximum length of {1} characters for ModelName '{2}'";
5454
catalog("llms:pngExpected") = "Argument must be a PNG image.";
5555
catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message.";
56+
catalog("llms:invalidOptionsForOpenAIBackEnd") = "The parameters Resource Name, Deployment ID and API Version are not compatible with OpenAI.";
57+
catalog("llms:invalidOptionsForAzureBackEnd") = "The parameter Model Name is not compatible with Azure.";
58+
5659
end

+llms/+utils/mustBeValidPenalty.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
function mustBeValidPenalty(value)
% Validates that VALUE is a real, nonsparse numeric scalar in [-2, 2],
% the range accepted by the API for presence/frequency penalties.
classes = {'numeric'};
attributes = {'real', 'scalar', 'nonsparse', '<=', 2, '>=', -2};
validateattributes(value, classes, attributes)
end

+llms/+utils/mustBeValidStop.m

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
function mustBeValidStop(value)
% Validates stop sequences: empty is allowed; otherwise VALUE must be a
% vector of nonzero-length text with at most 4 elements.
if isempty(value)
    return
end
mustBeVector(value);
mustBeNonzeroLengthText(value);
% This restriction is set by the OpenAI API
if numel(value) > 4
    error("llms:stopSequencesMustHaveMax4Elements", llms.utils.errorMessageCatalog.getMessage("llms:stopSequencesMustHaveMax4Elements"));
end
end

+llms/+utils/mustBeValidTemperature.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
function mustBeValidTemperature(value)
% Validates that VALUE is a real, nonsparse, nonnegative numeric scalar
% no greater than 2, the temperature range accepted by the API.
classes = {'numeric'};
attributes = {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 2};
validateattributes(value, classes, attributes)
end

+llms/+utils/mustBeValidTopP.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
function mustBeValidTopP(value)
% Validates that VALUE is a real, nonsparse, nonnegative numeric scalar
% no greater than 1, i.e. a valid top_p probability mass.
classes = {'numeric'};
attributes = {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 1};
validateattributes(value, classes, attributes)
end

README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,35 @@ messages = addUserMessageWithImages(messages,"What is in the image?",image_path)
288288
% Should output the description of the image
289289
```
290290
291+
## Establishing a connection to Chat Completions API using Azure®
292+
293+
If you would like to connect MATLAB to Chat Completions API via Azure® instead of directly with OpenAI, you will have to create an `azureChat` object.
294+
In addition to your Azure API key, you first need to obtain the name of your Azure OpenAI Resource.
295+
296+
In order to create the chat assistant, you must specify your Azure OpenAI Resource and the LLM you want to use:
297+
```matlab
298+
chat = azureChat(YOUR_RESOURCE_NAME, YOUR_DEPLOYMENT_NAME, "You are a helpful AI assistant");
299+
```
300+
301+
The `azureChat` object also allows you to specify additional options in the same way as the `openAIChat` object.
302+
However, the `ModelName` option is not available due to the fact that the name of the LLM is already specified when creating the chat assistant.
303+
304+
On the other hand, the `azureChat` object offers an additional option that allows you to set the API version that you want to use for the operation.
305+
306+
After establishing your connection with Azure, you can continue using the `generate` function and other objects in the same way as if you had established a connection directly with OpenAI:
307+
```matlab
308+
% Initialize the Azure Chat object, passing a system prompt and specifying the API version
309+
chat = azureChat(YOUR_RESOURCE_NAME, YOUR_DEPLOYMENT_NAME, "You are a helpful AI assistant", APIVersion="2023-12-01-preview");
310+
311+
% Create an openAIMessages object to start the conversation history
312+
history = openAIMessages;
313+
314+
% Ask your question and store it in the history, create the response using the generate method, and store the response in the history
315+
history = addUserMessage(history,"What is an eigenvalue?");
316+
[txt, response] = generate(chat, history)
317+
history = addResponseMessage(history, response);
318+
```
319+
291320
### Obtaining embeddings
292321
293322
You can extract embeddings from your text with OpenAI using the function `extractOpenAIEmbeddings` as follows:

0 commit comments

Comments
 (0)