Commit c28d6af

Add support for Ollama Generate API for using vision models
Signed-off-by: Demis Bellot <demis.bellot@gmail.com>
1 parent 4883b87 commit c28d6af

15 files changed: 729 additions, 17 deletions

AiServer.ServiceInterface/AiProviderFactory.cs

Lines changed: 14 additions & 6 deletions
```diff
@@ -11,14 +11,22 @@ public interface IOpenAiProvider
     Task<OpenAiChatResult> ChatAsync(AiProvider provider, OpenAiChat request, CancellationToken token = default);
 }
 
-public class AiProviderFactory(OpenAiProvider openAiProvider, GoogleAiProvider googleProvider, AnthropicAiProvider anthropicAiProvider)
+public record OllamaGenerationResult(OllamaGenerateResponse Response, int DurationMs);
+
+public interface IOllamaAiProvider
+{
+    Task<OllamaGenerationResult> GenerateAsync(AiProvider provider, OllamaGenerate request, CancellationToken token = default);
+}
+
+public class AiProviderFactory(OpenAiProvider openAiProvider, OllamaAiProvider ollamaAiProvider, GoogleAiProvider googleProvider, AnthropicAiProvider anthropicAiProvider)
 {
     public IOpenAiProvider GetOpenAiProvider(AiProviderType aiProviderType=AiProviderType.OpenAiProvider)
     {
-        return aiProviderType == AiProviderType.GoogleAiProvider
-            ? googleProvider
-            : aiProviderType == AiProviderType.AnthropicAiProvider
-                ? anthropicAiProvider
-                : openAiProvider;
+        return aiProviderType switch
+        {
+            AiProviderType.OllamaAiProvider => ollamaAiProvider,
+            AiProviderType.GoogleAiProvider => googleProvider,
+            AiProviderType.AnthropicAiProvider => anthropicAiProvider,
+            _ => openAiProvider
+        };
     }
 }
```
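GetOpenAiProvider remains the single resolution point, so callers reach the new Generate API by resolving the provider and downcasting to IOllamaAiProvider, just as CreateOllamaGenerationCommand does below. A minimal sketch under that assumption (the factory, provider and request instances are taken as given):

```csharp
// Sketch: reaching the new Generate API through the factory.
// AiProviderFactory, AiProvider and OllamaGenerate come from the diffs above;
// the instances are assumed to be dependency-injected / configured.
async Task<OllamaGenerateResponse> GenerateViaFactory(
    AiProviderFactory factory, AiProvider provider, OllamaGenerate request,
    CancellationToken token = default)
{
    var resolved = factory.GetOpenAiProvider(AiProviderType.OllamaAiProvider);
    if (resolved is not IOllamaAiProvider ollama)
        throw new NotSupportedException($"{resolved.GetType()} is not an IOllamaAiProvider");

    // OllamaGenerationResult is a positional record, so it deconstructs
    var (response, _) = await ollama.GenerateAsync(provider, request, token);
    return response;
}
```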
AiServer.ServiceInterface/AppDb/CompleteOllamaGenerateCommand.cs

Lines changed: 39 additions & 0 deletions

```diff
@@ -0,0 +1,39 @@
+using System.Data;
+using AiServer.ServiceModel;
+using ServiceStack;
+using ServiceStack.Jobs;
+using ServiceStack.OrmLite;
+
+namespace AiServer.ServiceInterface.AppDb;
+
+public record class CompleteOllamaGenerate(QueueOllamaGeneration Request, OllamaGenerateResponse Response, BackgroundJob Job);
+
+[Worker(Workers.AppDb)]
+public class CompleteOllamaGenerateCommand(IDbConnection db) : SyncCommand<CompleteOllamaGenerate>
+{
+    protected override void Run(CompleteOllamaGenerate ctx)
+    {
+        var summary = new ChatSummary
+        {
+            Id = ctx.Job.Id,
+            RefId = ctx.Job.RefId!,
+            CreatedDate = ctx.Job.CreatedDate,
+            DurationMs = ctx.Job.DurationMs,
+            Tag = ctx.Job.Tag,
+            Model = ctx.Request.Request.Model,
+            Provider = ctx.Job.Worker!,
+            PromptTokens = ctx.Response?.PromptTokens ?? 0,
+            CompletionTokens = ctx.Response?.EvalCount ?? 0,
+        };
+        try
+        {
+            db.Insert(summary);
+        }
+        catch (Exception e)
+        {
+            // completing failed jobs could fail with unique constraint
+            db.DeleteById<ChatSummary>(summary.Id);
+            db.Insert(summary);
+        }
+    }
+}
```
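The try/catch implements a small upsert: when a previously failed job is retried, a ChatSummary row with the same Job Id can already exist, so instead of surfacing the unique-constraint violation the stale row is deleted and the insert retried.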
AiServer.ServiceInterface/Jobs/CreateOllamaGenerationCommand.cs

Lines changed: 60 additions & 0 deletions

```diff
@@ -0,0 +1,60 @@
+using AiServer.ServiceInterface.AppDb;
+using AiServer.ServiceModel;
+using Microsoft.Extensions.Logging;
+using ServiceStack;
+using ServiceStack.Jobs;
+
+namespace AiServer.ServiceInterface.Jobs;
+
+public class CreateOllamaGenerationCommand(ILogger<CreateOllamaGenerationCommand> logger, IBackgroundJobs jobs, AppData appData, AiProviderFactory aiFactory, IHttpClientFactory clientFactory)
+    : AsyncCommandWithResult<QueueOllamaGeneration, OllamaGenerateResponse>
+{
+    protected override async Task<OllamaGenerateResponse> RunAsync(QueueOllamaGeneration request, CancellationToken token)
+    {
+        var job = Request.GetBackgroundJob();
+        var log = Request.CreateJobLogger(jobs,logger);
+        var apiProvider = appData.AssertAiProvider(job.Worker!);
+        var chatProvider = aiFactory.GetOpenAiProvider(apiProvider.AiType.Provider);
+        if (chatProvider is not IOllamaAiProvider generateProvider)
+            throw new NotSupportedException($"{chatProvider.GetType()} is not an IOllamaAiProvider");
+
+        try
+        {
+            var origModel = request.Request.Model;
+            request.Request.Model = appData.GetQualifiedModel(origModel) ?? origModel;
+            log.LogInformation("GENERATE Ollama #{JobId} request for {OriginalModel}, using {Model}", job.Id, origModel, request.Request.Model);
+            var (response, durationMs) = await generateProvider.GenerateAsync(apiProvider, request.Request, token);
+            request.Request.Model = origModel;
+
+            job.DurationMs = durationMs;
+            jobs.RunCommand<CompleteOllamaGenerateCommand>(
+                new CompleteOllamaGenerate(Request: request, Response: response, Job: job));
+
+            log.LogInformation("GENERATE Ollama #{JobId} request finished in {Ms} ms{ReplyMessage}",
+                job.Id, job.DurationMs, job.ReplyTo == null ? "" : $", sending response to {job.ReplyTo}");
+            if (job.ReplyTo != null)
+            {
+                await clientFactory.SendJsonCallbackAsync(Request.GetBackgroundJob().ReplyTo!, request, token:token);
+                // jobs.EnqueueCommand<NotifyOpenAiChatResponseCommand>(response, new() {
+                //     ParentId = job.Id,
+                //     ReplyTo = job.ReplyTo,
+                // });
+            }
+            return response;
+        }
+        catch (Exception e)
+        {
+            var offline = !await chatProvider.IsOnlineAsync(apiProvider, token);
+            log.LogError("CHAT OpenAi #{JobId} request failed after {Ms} with: {Message} (offline:{Offline})",
+                job.Id, job.DurationMs, e.Message, offline);
+            if (offline)
+            {
+                jobs.RunCommand<ChangeProviderStatusCommand>(new ChangeProviderStatus {
+                    Name = apiProvider.Name,
+                    OfflineDate = DateTime.UtcNow,
+                });
+            }
+            throw;
+        }
+    }
+}
```
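The OllamaAiProvider that fulfils GenerateAsync is among the changed files not shown on this page. Against Ollama's HTTP API, a non-streaming generation reduces to a single POST to /api/generate; the sketch below illustrates that wire format and is not the committed implementation (the host URL, model and image file are placeholders):

```csharp
using System.Text;
using System.Text.Json;

// Minimal sketch of a non-streaming Ollama generate call with an image,
// which is what vision-model support boils down to on the wire.
var baseUrl = "http://localhost:11434";   // an Ollama host
var imageBase64 = Convert.ToBase64String(await File.ReadAllBytesAsync("photo.jpg"));

using var http = new HttpClient();
var json = JsonSerializer.Serialize(new {
    model = "llava",                      // a vision-capable model
    prompt = "Describe this image",
    images = new[] { imageBase64 },       // base64-encoded image payloads
    stream = false,                       // one JSON object, not chunks
});
var startedAt = DateTime.UtcNow;
using var res = await http.PostAsync($"{baseUrl}/api/generate",
    new StringContent(json, Encoding.UTF8, "application/json"));
res.EnsureSuccessStatusCode();
var body = await res.Content.ReadAsStringAsync();
var durationMs = (int)(DateTime.UtcNow - startedAt).TotalMilliseconds;
// `body` carries the fields OllamaGenerateResponse mirrors: model, created_at,
// response, done, done_reason, context, total_duration, load_duration,
// prompt_eval_count, eval_count
Console.WriteLine($"{durationMs}ms: {body}");
```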

AiServer.ServiceInterface/Jobs/CreateOpenAiChatCommand.cs

Lines changed: 1 addition & 1 deletion
```diff
@@ -39,7 +39,7 @@ protected override async Task<OpenAiChatResponse> RunAsync(QueueOpenAiChatComple
             }
             return response;
         }
-        catch(Exception e)
+        catch (Exception e)
         {
             var offline = !await chatProvider.IsOnlineAsync(apiProvider, token);
             log.LogError("CHAT OpenAi #{JobId} request failed after {Ms} with: {Message} (offline:{Offline})",
```

AiServer.ServiceInterface/OpenAiChatServices.cs

Lines changed: 187 additions & 1 deletion
```diff
@@ -108,6 +108,19 @@ public object Any(GetModelImage request)
         return GetModelImage(request.Model);
     }
 
+    public async Task<object> Post(OllamaGeneration request)
+    {
+        var generateRequest = new QueueOllamaGeneration
+        {
+            Request = request,
+            RefId = request.RefId,
+            Tag = request.Tag,
+            Provider = request.Provider
+        };
+
+        return await generateRequest.ProcessSync(jobs, this);
+    }
+
     public async Task<object> Post(OpenAiChatCompletion request)
     {
         var chatRequest = new QueueOpenAiChatCompletion
```
```diff
@@ -121,6 +134,70 @@ public async Task<object> Post(OpenAiChatCompletion request)
         return await chatRequest.ProcessSync(jobs, this);
     }
 
+    public QueueOllamaGenerationResponse Any(QueueOllamaGeneration request)
+    {
+        if (request.Request == null)
+            throw new ArgumentNullException(nameof(request.Request));
+
+        if (request.Request.Prompt.IsNullOrEmpty())
+            throw new ArgumentNullException(nameof(request.Request.Prompt));
+
+        var qualifiedModel = appData.GetQualifiedModel(request.Request.Model);
+        if (qualifiedModel == null)
+            throw HttpError.NotFound($"Model {request.Request.Model} not found");
+
+        var queueCounts = jobs.GetWorkerQueueCounts();
+        var providerQueueCount = int.MaxValue;
+        AiProvider? useProvider = null;
+        var candidates = appData.AiProviders
+            .Where(x => x is { Enabled: true, AiType.Provider: AiProviderType.OllamaAiProvider }
+                && x.Models.Any(m => m.Model == qualifiedModel)).ToList();
+        foreach (var candidate in candidates)
+        {
+            if (candidate.OfflineDate != null)
+                continue;
+            var pendingJobs = queueCounts.GetValueOrDefault(candidate.Name, 0);
+            if (useProvider == null)
+            {
+                useProvider = candidate;
+                providerQueueCount = pendingJobs;
+                continue;
+            }
+            if (pendingJobs < providerQueueCount || (pendingJobs == providerQueueCount && candidate.Priority > useProvider.Priority))
+            {
+                useProvider = candidate;
+                providerQueueCount = pendingJobs;
+            }
+        }
+
+        useProvider ??= candidates.FirstOrDefault(x => x.Name == qualifiedModel); // Allow selecting offline models
+        if (useProvider == null)
+            throw new NotSupportedException("No active AI Providers support this model");
+
+        var jobRef = jobs.EnqueueCommand<CreateOllamaGenerationCommand>(request, new()
+        {
+            RefId = request.RefId,
+            ReplyTo = request.ReplyTo,
+            Tag = request.Tag,
+            Args = request.Provider == null ? null : new() {
+                [nameof(request.Provider)] = request.Provider
+            },
+            Worker = useProvider.Name,
+        });
+
+        var jobStatusUrl = AppConfig.Instance.ApplicationBaseUrl
+            .CombineWith($"/api/{nameof(GetOllamaGenerationStatus)}?RefId=" + jobRef.RefId);
+
+        var response = new QueueOllamaGenerationResponse
+        {
+            Id = jobRef.Id,
+            RefId = jobRef.RefId,
+            StatusUrl = jobStatusUrl
+        };
+
+        return response;
+    }
+
     public QueueOpenAiChatResponse Any(QueueOpenAiChatCompletion request)
     {
         if (request.Request == null)
```
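Provider selection here is load-based: among enabled Ollama providers that serve the requested model, the online candidate with the fewest pending jobs wins, and Priority breaks ties. For example, if providers A and B both serve the model with 3 and 5 queued jobs respectively, A is chosen; at 3 apiece, the higher Priority wins. Only when no online candidate exists does the fallback (per the inline comment) permit selecting an offline provider.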
```diff
@@ -292,6 +369,40 @@ private async Task<JobResult> WaitForJobCompletion(long jobId)
         return HttpError.NotFound("Job not found");
     }
 
+    public async Task<object> Get(GetOllamaGenerationStatus request)
+    {
+        var summary = GetJobSummary((int)request.JobId, request.RefId);
+        if (summary == null)
+            return HttpError.NotFound("JobSummary not found");
+
+        var response = GetOpenAiChat(summary);
+        if (response == null)
+            return HttpError.NotFound("Job not found");
+
+        var job = response.Result;
+
+        var generateResponse = response.Result?.ResponseBody.FromJson<OllamaGenerateResponse>();
+        if (generateResponse == null)
+        {
+            return new GetOllamaGenerationStatusResponse
+            {
+                JobId = request.JobId,
+                RefId = request.RefId,
+                JobState = job.State,
+                Status = job.State.ToString(),
+            };
+        }
+
+        return new GetOllamaGenerationStatusResponse
+        {
+            JobId = request.JobId,
+            RefId = request.RefId,
+            JobState = job.State,
+            Status = job.State.ToString(),
+            Result = generateResponse,
+        };
+    }
+
     public async Task<object> Get(GetOpenAiChatStatus request)
     {
         var summary = GetJobSummary((int)request.JobId, request.RefId);
```
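Together, the QueueOllamaGeneration and GetOllamaGenerationStatus endpoints above give callers a queue-and-poll workflow. A hypothetical client sketch using ServiceStack's JsonApiClient; only properties visible in these diffs are used, and the IReturn&lt;T&gt; contracts and server URL are assumed:

```csharp
using AiServer.ServiceModel;
using ServiceStack;
using ServiceStack.Jobs;

// Hypothetical queue-and-poll usage against an AiServer instance.
var client = new JsonApiClient("https://localhost:5001"); // placeholder URL

var queued = await client.PostAsync(new QueueOllamaGeneration {
    Request = new OllamaGenerate {
        Model = "llava",                // any model an enabled Ollama provider serves
        Prompt = "Describe this image",
    },
});

// Poll until the background job reaches a terminal state
GetOllamaGenerationStatusResponse status;
do
{
    await Task.Delay(1000);
    status = await client.GetAsync(new GetOllamaGenerationStatus {
        RefId = queued.RefId,
    });
} while (status.JobState is not (BackgroundJobState.Completed
    or BackgroundJobState.Cancelled or BackgroundJobState.Failed));

Console.WriteLine(status.Result?.Response); // the generated output, when completed
```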
```diff
@@ -428,6 +539,81 @@ public object Any(DeleteAiProvider request)
 
 public static class OpenAiChatServiceExtensions
 {
+    public static async Task<OllamaGenerateResponse> ProcessSync(this QueueOllamaGeneration generateRequest,
+        IBackgroundJobs jobs, OpenAiChatServices chatService)
+    {
+        QueueOllamaGenerationResponse? generateResponse = null;
+        try
+        {
+            var response = chatService.Any(generateRequest);
+            generateResponse = response;
+        }
+        catch (Exception e)
+        {
+            Console.WriteLine(e);
+            throw;
+        }
+
+        if (generateResponse == null)
+            throw new Exception("Failed to start chat request");
+
+        var job = jobs.GetJob(generateResponse.Id);
+        // For all requests, wait for the job to be created
+        while (job == null)
+        {
+            await Task.Delay(1000);
+            job = jobs.GetJob(generateResponse.Id);
+        }
+
+        // We know at this point, we definitely have a job
+        JobResult queuedJob = job;
+
+        var completedResponse = new OllamaGenerateResponse();
+
+        // Handle failed jobs
+        if (job.Failed != null)
+        {
+            throw new Exception(job.Failed.Error!.Message);
+        }
+
+        // Wait for the job to complete max 2 minutes
+        var timeout = DateTime.UtcNow.AddMinutes(2);
+        while (queuedJob?.Job?.State is not (BackgroundJobState.Completed or BackgroundJobState.Cancelled
+            or BackgroundJobState.Failed) && DateTime.UtcNow < timeout)
+        {
+            await Task.Delay(1000);
+            queuedJob = jobs.GetJob(generateResponse.Id);
+        }
+
+        // Check if the job is still not completed
+        if (queuedJob?.Job?.State is not (BackgroundJobState.Completed or BackgroundJobState.Cancelled
+            or BackgroundJobState.Failed))
+        {
+            throw new Exception("Job did not complete within the specified timeout.");
+        }
+
+        // Process successful job results
+        var jobResponseBody = queuedJob.Completed?.ResponseBody;
+        var jobRes = jobResponseBody.FromJson<OllamaGenerateResponse>();
+        if (jobRes != null)
+        {
+            completedResponse.Model = jobRes.Model;
+            completedResponse.CreatedAt = jobRes.CreatedAt;
+            completedResponse.Response = jobRes.Response;
+            completedResponse.Done = jobRes.Done;
+            completedResponse.Context = jobRes.Context;
+            completedResponse.DoneReason = jobRes.DoneReason;
+            completedResponse.TotalDuration = jobRes.TotalDuration;
+            completedResponse.LoadDuration = jobRes.LoadDuration;
+            completedResponse.PromptEvalCount = jobRes.PromptEvalCount;
+            completedResponse.EvalCount = jobRes.EvalCount;
+            completedResponse.PromptTokens = jobRes.PromptTokens;
+            completedResponse.ResponseStatus = jobRes.ResponseStatus;
+        }
+
+        return completedResponse;
+    }
+
     public static async Task<OpenAiChatResponse> ProcessSync(this QueueOpenAiChatCompletion chatRequest,
         IBackgroundJobs jobs, OpenAiChatServices chatService)
     {
```
```diff
@@ -443,7 +629,7 @@ public static async Task<OpenAiChatResponse> ProcessSync(this QueueOpenAiChatCom
             throw;
         }
 
-        if(chatResponse == null)
+        if (chatResponse == null)
             throw new Exception("Failed to start chat request");
 
         var job = jobs.GetJob(chatResponse.Id);
```
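The synchronous Post(OllamaGeneration) endpoint shown earlier wraps this same pipeline in one blocking call via the new ProcessSync extension above. A hypothetical client sketch; since the commit targets vision models the request DTO presumably also carries an image payload, but that property isn't visible in these diffs and is omitted:

```csharp
using AiServer.ServiceModel;
using ServiceStack;

// Hypothetical synchronous call: returns once the queued job completes
// (ProcessSync polls for up to 2 minutes). Assumes OllamaGeneration
// declares IReturn<OllamaGenerateResponse>; the URL is a placeholder.
var client = new JsonApiClient("https://localhost:5001");
var response = await client.PostAsync(new OllamaGeneration {
    Model = "llava",
    Prompt = "What is in this image?",
    // vision payload omitted: the image property isn't visible in this diff
});
Console.WriteLine(response.Response);
```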
