
Commit d127953

Merge pull request #77 from matlab-deep-learning/minp
Add min-p sampling for `ollamaChat`
2 parents 150d9c1 + b0023dc commit d127953

File tree

7 files changed: +51 -6 lines changed


+llms/+internal/callOllamaChatAPI.m

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,7 @@
     messages
     nvp.Temperature
     nvp.TopP
+    nvp.MinP
     nvp.TopK
     nvp.TailFreeSamplingZ
     nvp.StopSequences
@@ -103,6 +104,7 @@
 dict = dictionary();
 dict("Temperature") = "temperature";
 dict("TopP") = "top_p";
+dict("MinP") = "min_p";
 dict("TopK") = "top_k";
 dict("TailFreeSamplingZ") = "tfs_z";
 dict("StopSequences") = "stop";

+llms/+internal/textGenerator.m

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
     Temperature {llms.utils.mustBeValidTemperature} = 1

     %TopP Top probability mass to consider for generation.
-    TopP {llms.utils.mustBeValidTopP} = 1
+    TopP {llms.utils.mustBeValidProbability} = 1

     %StopSequences Sequences to stop the generation of tokens.
     StopSequences {llms.utils.mustBeValidStop} = {}

+llms/+utils/mustBeValidTopP.m renamed to +llms/+utils/mustBeValidProbability.m

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-function mustBeValidTopP(value)
+function mustBeValidProbability(value)
 % This function is undocumented and will change in a future release

 % Copyright 2024 The MathWorks, Inc.
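The diff shows only the renamed signature. A plausible body, inferred from the error identifiers the tests below expect (MATLAB:expectedNonnegative and MATLAB:notLessEqual are what validateattributes raises); the actual file contents may differ:

    function mustBeValidProbability(value)
        % Sketch: restrict value to a probability in [0,1]. These attributes
        % raise MATLAB:expectedNonnegative and MATLAB:notLessEqual, matching
        % the identifiers checked in tests/tollamaChat.m.
        validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', '<=', 1})
    end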

azureChat.m

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@
     nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
     nvp.APIVersion (1,1) string {mustBeAPIVersion} = "2024-02-01"
     nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
-    nvp.TopP {llms.utils.mustBeValidTopP} = 1
+    nvp.TopP {llms.utils.mustBeValidProbability} = 1
     nvp.StopSequences {llms.utils.mustBeValidStop} = {}
     nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
     nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = 0

ollamaChat.m

Lines changed: 11 additions & 2 deletions
@@ -23,6 +23,12 @@
 %                  words can appear in any particular place.
 %                  This is also known as top-p sampling.
 %
+%   MinP         - Minimum probability ratio for controlling the
+%                  diversity of the output. Default value is 0;
+%                  higher values imply that only the more likely
+%                  words can appear in any particular place.
+%                  This is also known as min-p sampling.
+%
 %   TopK         - Maximum number of most likely tokens that are
 %                  considered for output. Default is Inf, allowing
 %                  all tokens. Smaller values reduce diversity in
@@ -67,6 +73,7 @@
     Model (1,1) string
     Endpoint (1,1) string
     TopK (1,1) {mustBeReal,mustBePositive} = Inf
+    MinP (1,1) {llms.utils.mustBeValidProbability} = 0
     TailFreeSamplingZ (1,1) {mustBeReal} = 1
 end

@@ -76,7 +83,8 @@
     modelName {mustBeTextScalar}
     systemPrompt {llms.utils.mustBeTextOrEmpty} = []
     nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
-    nvp.TopP {llms.utils.mustBeValidTopP} = 1
+    nvp.TopP {llms.utils.mustBeValidProbability} = 1
+    nvp.MinP {llms.utils.mustBeValidProbability} = 0
     nvp.TopK (1,1) {mustBeReal,mustBePositive} = Inf
     nvp.StopSequences {llms.utils.mustBeValidStop} = {}
     nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
@@ -103,6 +111,7 @@
     this.ResponseFormat = nvp.ResponseFormat;
     this.Temperature = nvp.Temperature;
     this.TopP = nvp.TopP;
+    this.MinP = nvp.MinP;
     this.TopK = nvp.TopK;
     this.TailFreeSamplingZ = nvp.TailFreeSamplingZ;
     this.StopSequences = nvp.StopSequences;
@@ -146,7 +155,7 @@
     [text, message, response] = llms.internal.callOllamaChatAPI(...
         this.Model, messagesStruct, ...
         Temperature=this.Temperature, ...
-        TopP=this.TopP, TopK=this.TopK,...
+        TopP=this.TopP, MinP=this.MinP, TopK=this.TopK,...
         TailFreeSamplingZ=this.TailFreeSamplingZ,...
         StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
         ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
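As the help text above describes, min-p sampling discards tokens whose probability falls below MinP times the probability of the most likely token: MinP=0 keeps everything, while MinP=1 keeps only tokens tied with the most likely one. A toy illustration of the rule, followed by the new constructor argument in use ("mistral" is the model the tests use; the prompt is just an example):

    % Toy illustration of the min-p rule on a next-token distribution.
    probs = [0.5 0.3 0.15 0.05];              % token probabilities
    minP = 0.2;                               % MinP parameter
    keep = probs >= minP*max(probs);          % threshold 0.1 drops the 0.05 token
    filtered = probs(keep)/sum(probs(keep));  % renormalize the survivors

    % With this change, MinP can be set when constructing an ollamaChat:
    chat = ollamaChat("mistral", MinP=0.05);
    txt = generate(chat, "Tell me a story.");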

openAIChat.m

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@
     nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
     nvp.ModelName (1,1) string {mustBeModel} = "gpt-4o-mini"
     nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
-    nvp.TopP {llms.utils.mustBeValidTopP} = 1
+    nvp.TopP {llms.utils.mustBeValidProbability} = 1
     nvp.StopSequences {llms.utils.mustBeValidStop} = {}
     nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
     nvp.APIKey {mustBeNonzeroLengthTextScalar}

tests/tollamaChat.m

Lines changed: 34 additions & 0 deletions
@@ -61,6 +61,22 @@ function extremeTopK(testCase)
             testCase.verifyEqual(response1,response2);
         end

+        function extremeMinP(testCase)
+            %% This should work, and it does on some computers. On others, Ollama
+            %% receives the parameter, but either Ollama or llama.cpp fails to
+            %% honor it correctly.
+            testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");
+
+            % setting min-p to p=1 means only tokens with the same logit as
+            % the most likely one can be chosen, which will almost certainly
+            % only ever be one, so we expect to get a fixed response.
+            chat = ollamaChat("mistral",MinP=1);
+            prompt = "Min-p sampling with p=1 returns a definite answer.";
+            response1 = generate(chat,prompt);
+            response2 = generate(chat,prompt);
+            testCase.verifyEqual(response1,response2);
+        end
+
         function extremeTfsZ(testCase)
             %% This should work, and it does on some computers. On others, Ollama
             %% receives the parameter, but either Ollama or llama.cpp fails to
@@ -235,6 +251,16 @@ function queryModels(testCase)
                 "Value", -20, ...
                 "Error", "MATLAB:expectedNonnegative"), ...
                 ...
+            "MinPTooLarge", struct( ...
+                "Property", "MinP", ...
+                "Value", 20, ...
+                "Error", "MATLAB:notLessEqual"), ...
+                ...
+            "MinPTooSmall", struct( ...
+                "Property", "MinP", ...
+                "Value", -20, ...
+                "Error", "MATLAB:expectedNonnegative"), ...
+                ...
             "WrongTypeStopSequences", struct( ...
                 "Property", "StopSequences", ...
                 "Value", 123, ...
@@ -329,6 +355,14 @@ function queryModels(testCase)
             "Input",{{ "TopP" -20 }},...
             "Error","MATLAB:expectedNonnegative"),...
             ...
+            "MinPTooLarge",struct( ...
+                "Input",{{ "MinP" 20 }},...
+                "Error","MATLAB:notLessEqual"),...
+                ...
+            "MinPTooSmall",struct( ...
+                "Input",{{ "MinP" -20 }},...
+                "Error","MATLAB:expectedNonnegative"),...
+                ...
             "WrongTypeStopSequences",struct( ...
                 "Input",{{ "StopSequences" 123}},...
                 "Error","MATLAB:validators:mustBeNonzeroLengthText"),...
