Add min-p sampling for ollamaChat #77

Merged
3 commits merged on Aug 20, 2024
2 changes: 2 additions & 0 deletions +llms/+internal/callOllamaChatAPI.m
@@ -29,6 +29,7 @@
messages
nvp.Temperature
nvp.TopP
nvp.MinP
nvp.TopK
nvp.TailFreeSamplingZ
nvp.StopSequences
@@ -103,6 +104,7 @@
dict = dictionary();
dict("Temperature") = "temperature";
dict("TopP") = "top_p";
dict("MinP") = "min_p";
dict("TopK") = "top_k";
dict("TailFreeSamplingZ") = "tfs_z";
dict("StopSequences") = "stop";
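The right-hand sides of this dictionary are the option names that Ollama's chat endpoint accepts in the "options" field of a request. A minimal sketch of the mapping in use (the values and the surrounding request assembly are illustrative, not part of this hunk):

% Illustrative values only; the actual request building lives elsewhere in this file.
options = struct();
options.(dict("MinP")) = 0.05;    % becomes options.min_p
options.(dict("TopK")) = 40;      % becomes options.top_k
% jsonencode(options) yields '{"min_p":0.05,"top_k":40}', which is sent to
% Ollama as the "options" field of the chat request.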
2 changes: 1 addition & 1 deletion +llms/+internal/textGenerator.m
@@ -8,7 +8,7 @@
Temperature {llms.utils.mustBeValidTemperature} = 1

%TopP Top probability mass to consider for generation.
TopP {llms.utils.mustBeValidTopP} = 1
TopP {llms.utils.mustBeValidProbability} = 1

%StopSequences Sequences to stop the generation of tokens.
StopSequences {llms.utils.mustBeValidStop} = {}
+llms/+utils/mustBeValidTopP.m renamed to +llms/+utils/mustBeValidProbability.m
@@ -1,4 +1,4 @@
function mustBeValidTopP(value)
function mustBeValidProbability(value)
% This function is undocumented and will change in a future release

% Copyright 2024 The MathWorks, Inc.
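The body of the renamed validator is not shown in this diff. Judging from the error identifiers the tests below expect (MATLAB:expectedNonnegative and MATLAB:notLessEqual), it amounts to a validateattributes check along these lines (a sketch, not necessarily the exact implementation):

function mustBeValidProbability(value)
% Sketch: accept real numeric values in [0,1]; anything outside that range
% errors with the identifiers the tests rely on.
validateattributes(value, {'numeric'}, {'real', 'nonnegative', '<=', 1})
end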
2 changes: 1 addition & 1 deletion azureChat.m
@@ -109,7 +109,7 @@
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
nvp.APIVersion (1,1) string {mustBeAPIVersion} = "2024-02-01"
nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
nvp.TopP {llms.utils.mustBeValidTopP} = 1
nvp.TopP {llms.utils.mustBeValidProbability} = 1
nvp.StopSequences {llms.utils.mustBeValidStop} = {}
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = 0
13 changes: 11 additions & 2 deletions ollamaChat.m
@@ -23,6 +23,12 @@
% words can appear in any particular place.
% This is also known as top-p sampling.
%
% MinP - Minimum probability ratio for controlling the
% diversity of the output. Default value is 0;
% higher values imply that only the more likely
% words can appear in any particular place.
% This is also known as min-p sampling.
%
% TopK - Maximum number of most likely tokens that are
% considered for output. Default is Inf, allowing
% all tokens. Smaller values reduce diversity in
@@ -67,6 +73,7 @@
Model (1,1) string
Endpoint (1,1) string
TopK (1,1) {mustBeReal,mustBePositive} = Inf
MinP (1,1) {llms.utils.mustBeValidProbability} = 0
TailFreeSamplingZ (1,1) {mustBeReal} = 1
end

@@ -76,7 +83,8 @@
modelName {mustBeTextScalar}
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
nvp.TopP {llms.utils.mustBeValidTopP} = 1
nvp.TopP {llms.utils.mustBeValidProbability} = 1
nvp.MinP {llms.utils.mustBeValidProbability} = 0
nvp.TopK (1,1) {mustBeReal,mustBePositive} = Inf
nvp.StopSequences {llms.utils.mustBeValidStop} = {}
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
@@ -103,6 +111,7 @@
this.ResponseFormat = nvp.ResponseFormat;
this.Temperature = nvp.Temperature;
this.TopP = nvp.TopP;
this.MinP = nvp.MinP;
this.TopK = nvp.TopK;
this.TailFreeSamplingZ = nvp.TailFreeSamplingZ;
this.StopSequences = nvp.StopSequences;
@@ -146,7 +155,7 @@
[text, message, response] = llms.internal.callOllamaChatAPI(...
this.Model, messagesStruct, ...
Temperature=this.Temperature, ...
TopP=this.TopP, TopK=this.TopK,...
TopP=this.TopP, MinP=this.MinP, TopK=this.TopK,...
TailFreeSamplingZ=this.TailFreeSamplingZ,...
StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
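With MinP threaded through to callOllamaChatAPI, min-p sampling can be requested at construction time. A usage sketch (the model name "mistral" is the one used in the tests; MinP=0.05 is an arbitrary example value):

% Keep only tokens whose probability is at least 5% of the most likely token's
% probability at each sampling step; MinP is a ratio in [0,1], and the default
% of 0 disables the filter.
chat = ollamaChat("mistral", MinP=0.05);
response = generate(chat, "Why is the sky blue?");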
2 changes: 1 addition & 1 deletion openAIChat.m
@@ -94,7 +94,7 @@
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
nvp.ModelName (1,1) string {mustBeModel} = "gpt-4o-mini"
nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
nvp.TopP {llms.utils.mustBeValidTopP} = 1
nvp.TopP {llms.utils.mustBeValidProbability} = 1
nvp.StopSequences {llms.utils.mustBeValidStop} = {}
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
nvp.APIKey {mustBeNonzeroLengthTextScalar}
34 changes: 34 additions & 0 deletions tests/tollamaChat.m
@@ -61,6 +61,22 @@ function extremeTopK(testCase)
testCase.verifyEqual(response1,response2);
end

function extremeMinP(testCase)
%% This should work, and it does on some computers. On others, Ollama
%% receives the parameter, but either Ollama or llama.cpp fails to
%% honor it correctly.
testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");

% setting min-p to p=1 means only tokens with the same logit as
% the most likely one can be chosen, which will almost certainly
% only ever be one, so we expect to get a fixed response.
chat = ollamaChat("mistral",MinP=1);
prompt = "Min-p sampling with p=1 returns a definite answer.";
response1 = generate(chat,prompt);
response2 = generate(chat,prompt);
testCase.verifyEqual(response1,response2);
end

function extremeTfsZ(testCase)
%% This should work, and it does on some computers. On others, Ollama
%% receives the parameter, but either Ollama or llama.cpp fails to
@@ -235,6 +251,16 @@ function queryModels(testCase)
"Value", -20, ...
"Error", "MATLAB:expectedNonnegative"), ...
...
"MinPTooLarge", struct( ...
"Property", "MinP", ...
"Value", 20, ...
"Error", "MATLAB:notLessEqual"), ...
...
"MinPTooSmall", struct( ...
"Property", "MinP", ...
"Value", -20, ...
"Error", "MATLAB:expectedNonnegative"), ...
...
"WrongTypeStopSequences", struct( ...
"Property", "StopSequences", ...
"Value", 123, ...
@@ -329,6 +355,14 @@ function queryModels(testCase)
"Input",{{ "TopP" -20 }},...
"Error","MATLAB:expectedNonnegative"),...I
...
"MinPTooLarge",struct( ...
"Input",{{ "MinP" 20 }},...
"Error","MATLAB:notLessEqual"),...
...
"MinPTooSmall",struct( ...
"Input",{{ "MinP" -20 }},...
"Error","MATLAB:expectedNonnegative"),...I
...
"WrongTypeStopSequences",struct( ...
"Input",{{ "StopSequences" 123}},...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
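The filtering itself happens inside Ollama/llama.cpp rather than in this toolbox, but the reasoning in the extremeMinP test above can be sketched directly: with MinP=1, only tokens tied with the most probable token survive the filter, so generation becomes effectively deterministic.

% Conceptual sketch of min-p filtering (probabilities are an invented example).
probs = [0.55 0.25 0.15 0.05];       % next-token probabilities
minP  = 1;                           % the extreme value used in extremeMinP
keep  = probs >= minP * max(probs);  % min-p keeps tokens above minP * max(probs)
find(keep)                           % only token 1 survives, so the output is fixed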