Skip to content

Commit f4f97b9

Browse files
authored
Merge pull request #36 from matlab-deep-learning/refactor-2024-05-16
little bit of code reformatting
2 parents dbd8b70 + 11303d6 commit f4f97b9

File tree

1 file changed

+16
-60
lines changed

1 file changed

+16
-60
lines changed

openAIChat.m

Lines changed: 16 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,19 @@
7070

7171
properties
7272
%TEMPERATURE Temperature of generation.
73-
Temperature
73+
Temperature {mustBeValidTemperature} = 1
7474

7575
%TOPPROBABILITYMASS Top probability mass to consider for generation.
76-
TopProbabilityMass
76+
TopProbabilityMass {mustBeValidTopP} = 1
7777

7878
%STOPSEQUENCES Sequences to stop the generation of tokens.
79-
StopSequences
79+
StopSequences {mustBeValidStop} = {}
8080

8181
%PRESENCEPENALTY Penalty for using a token in the response that has already been used.
82-
PresencePenalty
82+
PresencePenalty {mustBeValidPenalty} = 0
8383

8484
%FREQUENCYPENALTY Penalty for using a token that is frequent in the training data.
85-
FrequencyPenalty
85+
FrequencyPenalty {mustBeValidPenalty} = 0
8686
end
8787

8888
properties(SetAccess=private)
@@ -114,11 +114,13 @@
114114
arguments
115115
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
116116
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
117-
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4o", ...
118-
"gpt-4o-2024-05-13","gpt-4-turbo", ...
119-
"gpt-4-turbo-2024-04-09","gpt-4","gpt-4-0613", ...
120-
"gpt-3.5-turbo","gpt-3.5-turbo-0125", ...
121-
"gpt-3.5-turbo-1106"])} = "gpt-3.5-turbo"
117+
nvp.ModelName (1,1) string {mustBeMember(nvp.ModelName,[...
118+
"gpt-4o","gpt-4o-2024-05-13",...
119+
"gpt-4-turbo","gpt-4-turbo-2024-04-09",...
120+
"gpt-4","gpt-4-0613", ...
121+
"gpt-3.5-turbo","gpt-3.5-turbo-0125", ...
122+
"gpt-3.5-turbo-1106",...
123+
])} = "gpt-3.5-turbo"
122124
nvp.Temperature {mustBeValidTemperature} = 1
123125
nvp.TopProbabilityMass {mustBeValidTopP} = 1
124126
nvp.StopSequences {mustBeValidStop} = {}
@@ -147,7 +149,7 @@
147149

148150
if ~isempty(systemPrompt)
149151
systemPrompt = string(systemPrompt);
150-
if ~(strlength(systemPrompt)==0)
152+
if systemPrompt ~= ""
151153
this.SystemPrompt = {struct("role", "system", "content", systemPrompt)};
152154
end
153155
end
@@ -158,7 +160,7 @@
158160
this.StopSequences = nvp.StopSequences;
159161

160162
% ResponseFormat is only supported in the latest models only
161-
if (nvp.ResponseFormat == "json")
163+
if nvp.ResponseFormat == "json"
162164
if ismember(this.ModelName,["gpt-4","gpt-4-0613"])
163165
error("llms:invalidOptionAndValueForModel", ...
164166
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", this.ModelName));
@@ -243,52 +245,6 @@
243245
end
244246

245247
end
246-
247-
function this = set.Temperature(this, temperature)
248-
arguments
249-
this openAIChat
250-
temperature
251-
end
252-
mustBeValidTemperature(temperature);
253-
254-
this.Temperature = temperature;
255-
end
256-
257-
function this = set.TopProbabilityMass(this,topP)
258-
arguments
259-
this openAIChat
260-
topP
261-
end
262-
mustBeValidTopP(topP);
263-
this.TopProbabilityMass = topP;
264-
end
265-
266-
function this = set.StopSequences(this,stop)
267-
arguments
268-
this openAIChat
269-
stop
270-
end
271-
mustBeValidStop(stop);
272-
this.StopSequences = stop;
273-
end
274-
275-
function this = set.PresencePenalty(this,penalty)
276-
arguments
277-
this openAIChat
278-
penalty
279-
end
280-
mustBeValidPenalty(penalty)
281-
this.PresencePenalty = penalty;
282-
end
283-
284-
function this = set.FrequencyPenalty(this,penalty)
285-
arguments
286-
this openAIChat
287-
penalty
288-
end
289-
mustBeValidPenalty(penalty)
290-
this.FrequencyPenalty = penalty;
291-
end
292248
end
293249

294250
methods(Hidden)
@@ -331,7 +287,7 @@ function mustBeNonzeroLengthTextScalar(content)
331287

332288
for i = 1:numFunctions
333289
functionsStruct{i} = struct('type','function', ...
334-
'function',encodeStruct(functions(i))) ;
290+
'function',encodeStruct(functions(i)));
335291
functionNames(i) = functions(i).FunctionName;
336292
end
337293
end
@@ -377,4 +333,4 @@ function mustBeIntegerOrEmpty(value)
377333
if ~isempty(value)
378334
mustBeInteger(value)
379335
end
380-
end
336+
end

0 commit comments

Comments
 (0)