Skip to content

Commit 8bd236b

Browse files
committed
get basic Azure connection working
1 parent 26b1272 commit 8bd236b

File tree

4 files changed

+41
-38
lines changed

4 files changed

+41
-38
lines changed

+llms/+internal/callAzureChatAPI.m

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function [text, message, response] = callAzureChatAPI(resourceName, deploymentID, messages, functions, nvp)
1+
function [text, message, response] = callAzureChatAPI(endpoint, deploymentID, messages, functions, nvp)
22
%callOpenAIChatAPI Calls the openAI chat completions API.
33
%
44
% MESSAGES and FUNCTIONS should be structs matching the json format
@@ -52,7 +52,7 @@
5252
% Copyright 2023-2024 The MathWorks, Inc.
5353

5454
arguments
55-
resourceName
55+
endpoint
5656
deploymentID
5757
messages
5858
functions
@@ -72,11 +72,11 @@
7272
nvp.StreamFun = []
7373
end
7474

75-
END_POINT = "https://" + resourceName + ".openai.azure.com/openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;
75+
URL = endpoint + "openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;
7676

7777
parameters = buildParametersCall(messages, functions, nvp);
7878

79-
[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);
79+
[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, URL, nvp.TimeOut, nvp.StreamFun);
8080

8181
% If call errors, "choices" will not be part of response.Body.Data, instead
8282
% we get response.Body.Data.error
@@ -108,9 +108,13 @@
108108

109109
parameters.stream = ~isempty(nvp.StreamFun);
110110

111-
parameters.tools = functions;
111+
if ~isempty(functions)
112+
parameters.tools = functions;
113+
end
112114

113-
parameters.tool_choice = nvp.ToolChoice;
115+
if ~isempty(nvp.ToolChoice)
116+
parameters.tool_choice = nvp.ToolChoice;
117+
end
114118

115119
if ~isempty(nvp.Seed)
116120
parameters.seed = nvp.Seed;

+llms/+internal/sendRequest.m

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
function [response, streamedText] = sendRequest(parameters, token, endpoint, timeout, streamFun)
22
%sendRequest Sends a request to an ENDPOINT using PARAMETERS and
3-
% api key TOKEN. TIMEOUT is the nubmer of seconds to wait for initial
3+
% api key TOKEN. TIMEOUT is the number of seconds to wait for initial
44
% server connection. STREAMFUN is an optional callback function.
55

66
% Copyright 2023 The MathWorks, Inc.
@@ -16,7 +16,8 @@
1616
% Define the headers for the API request
1717

1818
headers = [matlab.net.http.HeaderField('Content-Type', 'application/json')...
19-
matlab.net.http.HeaderField('Authorization', "Bearer " + token)];
19+
matlab.net.http.HeaderField('Authorization', "Bearer " + token)...
20+
matlab.net.http.HeaderField('api-key',token)];
2021

2122
% Define the request message
2223
request = matlab.net.http.RequestMessage('post',headers,parameters);

azureChat.m

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
classdef(Sealed) azureChat < llms.internal.textGenerator
22
%azureChat Chat completion API from Azure.
33
%
4-
% CHAT = azureChat(resourceName, deploymentID) creates an azureChat object with the
5-
% resource name and deployment ID path parameters required by Azure to establish the connection.
4+
% CHAT = azureChat(endpoint, deploymentID) creates an azureChat object with the
5+
% endpoint and deployment ID path parameters required by Azure to establish the connection.
66
%
77
% CHAT = azureChat(systemPrompt) creates an azureChatobject with the
88
% specified system prompt.
@@ -74,16 +74,15 @@
7474
% Copyright 2023-2024 The MathWorks, Inc.
7575

7676
properties(SetAccess=private)
77-
ResourceName
78-
DeploymentID
79-
APIVersion
77+
Endpoint (1,1) string
78+
DeploymentID (1,1) string
79+
APIVersion (1,1) string
8080
end
8181

82-
8382
methods
84-
function this = azureChat(resourceName, deploymentID, systemPrompt, nvp)
83+
function this = azureChat(endpoint, deploymentID, systemPrompt, nvp)
8584
arguments
86-
resourceName {mustBeTextScalar}
85+
endpoint {mustBeTextScalar}
8786
deploymentID {mustBeTextScalar}
8887
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
8988
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
@@ -123,7 +122,7 @@
123122
end
124123
end
125124

126-
this.ResourceName = resourceName;
125+
this.Endpoint = endpoint;
127126
this.DeploymentID = deploymentID;
128127
this.APIVersion = nvp.APIVersion;
129128
this.ResponseFormat = nvp.ResponseFormat;
@@ -181,7 +180,7 @@
181180
end
182181

183182
toolChoice = convertToolChoice(this, nvp.ToolChoice);
184-
[text, message, response] = llms.internal.callAzureChatAPI(this.ResourceName, ...
183+
[text, message, response] = llms.internal.callAzureChatAPI(this.Endpoint, ...
185184
this.DeploymentID, messagesStruct, this.FunctionsStruct, ...
186185
ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=this.Temperature, ...
187186
TopProbabilityMass=this.TopProbabilityMass, NumCompletions=nvp.NumCompletions,...

tests/tazureChat.m

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,6 @@
33

44
% Copyright 2024 The MathWorks, Inc.
55

6-
methods (TestClassSetup)
7-
function saveEnvVar(testCase)
8-
% Ensures key is not in environment variable for tests
9-
azureKeyVar = "AZURE_OPENAI_API_KEY";
10-
if isenv(azureKeyVar)
11-
key = getenv(azureKeyVar);
12-
unsetenv(azureKeyVar);
13-
testCase.addTeardown(@(x) setenv(azureKeyVar, x), key);
14-
end
15-
end
16-
end
17-
186
properties(TestParameter)
197
InvalidConstructorInput = iGetInvalidConstructorInput;
208
InvalidGenerateInput = iGetInvalidGenerateInput;
@@ -24,11 +12,14 @@ function saveEnvVar(testCase)
2412
methods(Test)
2513
% Test methods
2614
function keyNotFound(testCase)
27-
testCase.verifyError(@()azureChat("My_resource", "Deployment"), "llms:keyMustBeSpecified");
15+
import matlab.unittest.fixtures.EnvironmentVariableFixture
16+
testCase.applyFixture(EnvironmentVariableFixture("AZURE_OPENAI_API_KEY","dummy"));
17+
unsetenv("AZURE_OPENAI_API_KEY");
18+
testCase.verifyError(@()azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT")), "llms:keyMustBeSpecified");
2819
end
2920

3021
function constructChatWithAllNVP(testCase)
31-
resourceName = "resource";
22+
endpoint = getenv("AZURE_OPENAI_ENDPOINT");
3223
deploymentID = "hello";
3324
functions = openAIFunction("funName");
3425
temperature = 0;
@@ -39,7 +30,7 @@ function constructChatWithAllNVP(testCase)
3930
frequenceP = 2;
4031
systemPrompt = "This is a system prompt";
4132
timeout = 3;
42-
chat = azureChat(resourceName, deploymentID, systemPrompt, Tools=functions, ...
33+
chat = azureChat(endpoint, deploymentID, systemPrompt, Tools=functions, ...
4334
Temperature=temperature, TopProbabilityMass=topP, StopSequences=stop, ApiKey=apiKey,...
4435
FrequencyPenalty=frequenceP, PresencePenalty=presenceP, TimeOut=timeout);
4536
testCase.verifyEqual(chat.Temperature, temperature);
@@ -49,28 +40,36 @@ function constructChatWithAllNVP(testCase)
4940
testCase.verifyEqual(chat.PresencePenalty, presenceP);
5041
end
5142

43+
function doGenerate(testCase)
44+
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
45+
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"));
46+
response = testCase.verifyWarningFree(@() generate(chat,"hi"));
47+
testCase.verifyClass(response,'string');
48+
testCase.verifyGreaterThan(strlength(response),0);
49+
end
50+
5251
function verySmallTimeOutErrors(testCase)
53-
chat = azureChat("My_resource", "Deployment", TimeOut=0.0001, ApiKey="false-key");
52+
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), TimeOut=0.0001, ApiKey="false-key");
5453
testCase.verifyError(@()generate(chat, "hi"), "MATLAB:webservices:Timeout")
5554
end
5655

5756
function errorsWhenPassingToolChoiceWithEmptyTools(testCase)
58-
chat = azureChat("My_resource", "Deployment", ApiKey="this-is-not-a-real-key");
57+
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), ApiKey="this-is-not-a-real-key");
5958
testCase.verifyError(@()generate(chat,"input", ToolChoice="bla"), "llms:mustSetFunctionsForCall");
6059
end
6160

6261
function invalidInputsConstructor(testCase, InvalidConstructorInput)
63-
testCase.verifyError(@()azureChat("My_resource", "Deployment", InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
62+
testCase.verifyError(@()azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
6463
end
6564

6665
function invalidInputsGenerate(testCase, InvalidGenerateInput)
6766
f = openAIFunction("validfunction");
68-
chat = azureChat("My_resource", "Deployment", Tools=f, ApiKey="this-is-not-a-real-key");
67+
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), Tools=f, ApiKey="this-is-not-a-real-key");
6968
testCase.verifyError(@()generate(chat,InvalidGenerateInput.Input{:}), InvalidGenerateInput.Error);
7069
end
7170

7271
function invalidSetters(testCase, InvalidValuesSetters)
73-
chat = azureChat("My_resource", "Deployment", ApiKey="this-is-not-a-real-key");
72+
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), ApiKey="this-is-not-a-real-key");
7473
function assignValueToProperty(property, value)
7574
chat.(property) = value;
7675
end

0 commit comments

Comments
 (0)