Add support for remote ollama #53

Merged (6 commits, Jul 24, 2024)

Changes from 2 commits
6 changes: 5 additions & 1 deletion +llms/+internal/callOllamaChatAPI.m
@@ -37,9 +37,13 @@
nvp.Seed
nvp.TimeOut
nvp.StreamFun
nvp.Endpoint
end

URL = "http://localhost:11434/api/chat";
URL = nvp.Endpoint + "/api/chat";
if ~startsWith(URL,"http")
URL = "http://" + URL;
end

% The JSON for StopSequences must have an array, and cannot say "stop": "foo".
% The easiest way to ensure that is to never pass in a scalar …
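As a quick check of the new behavior, the normalization added above can be exercised on its own. The following is a minimal sketch; `resolveChatURL` is a hypothetical helper name, not part of this PR, and its body simply mirrors the lines added above:

```matlab
function URL = resolveChatURL(endpoint)
% Append the chat path, then prepend "http://" if the endpoint has no scheme,
% matching the normalization introduced in callOllamaChatAPI.m.
URL = endpoint + "/api/chat";
if ~startsWith(URL,"http")
    URL = "http://" + URL;
end
end

% Expected results:
% resolveChatURL("127.0.0.1:11434")           -> "http://127.0.0.1:11434/api/chat"
% resolveChatURL("http://ollamaServer:11434") -> "http://ollamaServer:11434/api/chat"
```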
7 changes: 6 additions & 1 deletion .github/workflows/ci.yml
@@ -14,19 +14,23 @@ jobs:
run: |
# Run the background, there is no way to daemonise at the moment
ollama serve &
# Run a second server to test different endpoint
OLLAMA_HOST=127.0.0.1:11435 OLLAMA_MODELS=/tmp/ollama/models ollama serve &

# A short pause is required before the HTTP port is opened
sleep 5

# This endpoint blocks until ready
time curl -i http://localhost:11434
time curl -i http://localhost:11435

# For debugging, record Ollama version
ollama --version

- name: Pull mistral model
- name: Pull models
run: |
ollama pull mistral
OLLAMA_HOST=127.0.0.1:11435 ollama pull qwen2:0.5b
- name: Set up MATLAB
uses: matlab-actions/setup-matlab@v2
with:
@@ -39,6 +43,7 @@ jobs:
AZURE_OPENAI_DEPLOYMENT: ${{ secrets.AZURE_DEPLOYMENT }}
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_ENDPOINT }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_KEY }}
SECOND_OLLAMA_ENDPOINT: 127.0.0.1:11435
uses: matlab-actions/run-tests@v2
with:
test-results-junit: test-results/results.xml
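The workflow above exercises the new test point by exporting `SECOND_OLLAMA_ENDPOINT`. The same can be done locally from MATLAB once a second server is running; a sketch, assuming the second instance uses port 11435 as in the workflow:

```matlab
% Start the second server beforehand, e.g. from a shell:
%   OLLAMA_HOST=127.0.0.1:11435 ollama serve &
%   OLLAMA_HOST=127.0.0.1:11435 ollama pull qwen2:0.5b
setenv("SECOND_OLLAMA_ENDPOINT","127.0.0.1:11435");
results = runtests("tests/tollamaChat.m");
disp(results)
```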
2 changes: 2 additions & 0 deletions .gitignore
@@ -6,3 +6,5 @@ papers_to_read.csv
data/*
examples/data/*
._*
.nfs*
.DS_Store
7 changes: 7 additions & 0 deletions doc/Ollama.md
@@ -95,3 +95,10 @@ chat = ollamaChat("mistral", StreamFun=sf);
txt = generate(chat,"What is Model-Based Design and how is it related to Digital Twin?");
% Should stream the response token by token
```

## Establishing a connection to remote LLMs using Ollama

To connect to a remote Ollama server, use the `Endpoint` parameter. Include the server name and port number (Ollama listens on port 11434 by default):
```matlab
chat = ollamaChat("mistral",Endpoint="ollamaServer:11434");
```
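Because `callOllamaChatAPI` prepends `http://` when the endpoint carries no scheme (see the change above), the scheme is optional. Both constructions below reach the same server; this sketch reuses the hypothetical host name from the documentation example:

```matlab
% Scheme omitted: the client adds "http://" automatically.
chat1 = ollamaChat("mistral",Endpoint="ollamaServer:11434");

% Scheme given explicitly: equivalent.
chat2 = ollamaChat("mistral",Endpoint="http://ollamaServer:11434");
```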
10 changes: 7 additions & 3 deletions ollamaChat.m
@@ -64,8 +64,9 @@
% Copyright 2024 The MathWorks, Inc.

properties
Model (1,1) string
TopK (1,1) {mustBeReal,mustBePositive} = Inf
Model (1,1) string
Endpoint (1,1) string
TopK (1,1) {mustBeReal,mustBePositive} = Inf
TailFreeSamplingZ (1,1) {mustBeReal} = 1
end

@@ -82,6 +83,7 @@
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 120
nvp.TailFreeSamplingZ (1,1) {mustBeReal} = 1
nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
nvp.Endpoint (1,1) string = "127.0.0.1:11434"
end

if isfield(nvp,"StreamFun")
@@ -105,6 +107,7 @@
this.TailFreeSamplingZ = nvp.TailFreeSamplingZ;
this.StopSequences = nvp.StopSequences;
this.TimeOut = nvp.TimeOut;
this.Endpoint = nvp.Endpoint;
end

function [text, message, response] = generate(this, messages, nvp)
@@ -147,7 +150,8 @@
TailFreeSamplingZ=this.TailFreeSamplingZ,...
StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
TimeOut=this.TimeOut, StreamFun=this.StreamFun);
TimeOut=this.TimeOut, StreamFun=this.StreamFun, ...
Endpoint=this.Endpoint);

if isfield(response.Body.Data,"error")
err = response.Body.Data.error;
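Since `Endpoint` defaults to `"127.0.0.1:11434"`, existing code that omits the argument keeps talking to the local server. A brief sketch; the remote host name below is hypothetical:

```matlab
% Default endpoint: local Ollama server on 127.0.0.1:11434, behavior unchanged.
chatLocal = ollamaChat("mistral");

% Remote endpoint: any reachable Ollama instance, given as host:port.
chatRemote = ollamaChat("mistral",Endpoint="myOllamaHost:11434");
```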
13 changes: 11 additions & 2 deletions tests/tollamaChat.m
@@ -98,7 +98,6 @@ function seedFixesResult(testCase)
testCase.verifyEqual(response1,response2);
end


function streamFunc(testCase)
function seen = sf(str)
persistent data;
@@ -118,6 +117,17 @@ function streamFunc(testCase)
testCase.verifyGreaterThan(numel(sf("")), 1);
end

function reactToEndpoint(testCase)
testCase.assumeTrue(isenv("SECOND_OLLAMA_ENDPOINT"),...
"Test point assumes a second Ollama server is running " + ...
"and $SECOND_OLLAMA_ENDPOINT points to it.");
chat = ollamaChat("qwen2:0.5b",Endpoint=getenv("SECOND_OLLAMA_ENDPOINT"));
testCase.verifyWarningFree(@() generate(chat,"dummy"));
% also make sure "http://" can be included
chat = ollamaChat("qwen2:0.5b",Endpoint="http://" + getenv("SECOND_OLLAMA_ENDPOINT"));
testCase.verifyWarningFree(@() generate(chat,"dummy"));
end

function doReturnErrors(testCase)
testCase.assumeFalse( ...
any(startsWith(ollamaChat.models,"abcdefghijklmnop")), ...
@@ -126,7 +136,6 @@ function doReturnErrors(testCase)
testCase.verifyError(@() generate(chat,"hi!"), "llms:apiReturnedError");
end


function invalidInputsConstructor(testCase, InvalidConstructorInput)
testCase.verifyError(@() ollamaChat("mistral", InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
end