Skip to content

Commit 500981e

Browse files
committed
Add support for using Ollama Vision models in ImageToText APIs and UI
Signed-off-by: Demis Bellot <demis.bellot@gmail.com>
1 parent c28d6af commit 500981e

File tree

10 files changed

+531
-84
lines changed

10 files changed

+531
-84
lines changed

AiServer.ServiceInterface/GenerationServices.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
namespace AiServer.ServiceInterface;
99

10-
public class GenerationServices(IBackgroundJobs jobs, AppData appData) : Service
10+
public class GenerationServices(IBackgroundJobs jobs, AppData appData, AppConfig appConfig) : Service
1111
{
1212
public GetTextGenerationStatusResponse Any(GetTextGenerationStatus request)
1313
{
@@ -169,6 +169,38 @@ public async Task<ArtifactGenerationResponse> Any(ImageWithMask request)
169169

170170
public async Task<TextGenerationResponse> Any(ImageToText request)
171171
{
172+
if (request is { Model: not null, Prompt: not null })
173+
{
174+
var filesMap = Request!.HandleFileUploads(appConfig);
175+
if (filesMap.Count == 0)
176+
throw new ArgumentNullException(nameof(request.Image));
177+
178+
var file = Request!.Files.First(x => x.Name == filesMap.First().Key);
179+
var fileBytes = file.InputStream.ReadFully();
180+
var generateRequest = new QueueOllamaGeneration
181+
{
182+
Request = new()
183+
{
184+
Model = request.Model,
185+
Prompt = request.Prompt,
186+
Images = [Convert.ToBase64String(fileBytes)],
187+
},
188+
RefId = request.RefId,
189+
Tag = request.Tag,
190+
};
191+
192+
await using var chatService = HostContext.ResolveService<OpenAiChatServices>(Request);
193+
var response = await generateRequest.ProcessSync(jobs, chatService);
194+
195+
return new TextGenerationResponse
196+
{
197+
Results = response.Response != null ? [
198+
new() { Text = response.Response }
199+
] : null,
200+
ResponseStatus = response.ResponseStatus
201+
};
202+
}
203+
172204
var diffRequest = new CreateGeneration
173205
{
174206
Request = new()

AiServer.ServiceInterface/MediaProviderServices.cs

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ public object Any(CreateGeneration request)
159159
if (Request != null && Request.Files.Length > 0)
160160
{
161161
log.LogInformation("Saving {Count} uploaded files", Request.Files.Length);
162-
var fileMap = HandleFileUploads(Request);
162+
var fileMap = Request.HandleFileUploads(appConfig);
163163
args["Files"] = fileMap.ToJson();
164164
}
165165

@@ -178,48 +178,6 @@ public object Any(CreateGeneration request)
178178
};
179179
}
180180

181-
private Dictionary<string,string> HandleFileUploads(IRequest originalRequest)
182-
{
183-
var apiKeyId = (Request.GetApiKey() as ApiKeysFeature.ApiKey)?.Id ?? 0;
184-
var now = DateTime.UtcNow;
185-
var fileMap = new Dictionary<string, string>();
186-
foreach (var file in originalRequest.Files.Where(x =>
187-
supportedUploadNames.Contains(x.Name.ToLower()))
188-
)
189-
{
190-
var ext = file.FileName.Contains(".") ? file.FileName.SplitOnLast(".").Last() : GetFileExtension(file.ContentType);
191-
var memCpy = new MemoryStream();
192-
file.InputStream.CopyTo(memCpy);
193-
memCpy.Position = 0;
194-
var bytes = memCpy.ReadFully();
195-
var sha256 = bytes.ComputeSha256();
196-
var uploadDir = "/input/" + now.ToString("yyyy/MM/dd/") + apiKeyId + "/";
197-
Directory.CreateDirectory(appConfig.ArtifactsPath.CombineWith(uploadDir));
198-
var uploadFile = appConfig.ArtifactsPath.CombineWith(uploadDir, sha256 + "." + ext);
199-
using (var fs = File.OpenWrite(uploadFile))
200-
{
201-
fs.Write(bytes, 0, bytes.Length);
202-
fs.Close();
203-
}
204-
file.InputStream.Position = 0; // Reset the stream position if needed for further processing
205-
fileMap[file.Name] = uploadFile;
206-
}
207-
return fileMap;
208-
}
209-
210-
private string GetFileExtension(string contentType)
211-
{
212-
return contentType switch
213-
{
214-
"audio/mpeg" => "mp3",
215-
"audio/x-wav" => "wav",
216-
"audio/wav" => "wav",
217-
_ => "webp"
218-
};
219-
}
220-
221-
private static string[] supportedUploadNames = ["audio", "image", "mask"];
222-
223181
public object Any(QueryMediaModels request)
224182
{
225183
// Ensure all model settings have Id as related model name
@@ -386,4 +344,50 @@ public static class BackgroundJobsFeatureExtensions
386344
var summary = db.Single(q);
387345
return summary;
388346
}
347+
}
348+
349+
/// <summary>
/// Request-level helpers for persisting uploaded media files under the artifacts directory.
/// </summary>
public static class MediaProviderExtensions
{
    // Form field names that are accepted as uploads; any other uploaded file is ignored.
    private static readonly string[] supportedUploadNames = ["audio", "image", "mask"];

    /// <summary>
    /// Saves every uploaded file whose form field name is "audio", "image" or "mask"
    /// to <c>{ArtifactsPath}/input/yyyy/MM/dd/{apiKeyId}/{sha256}.{ext}</c>.
    /// </summary>
    /// <param name="request">The current request whose <c>Files</c> are inspected.</param>
    /// <param name="appConfig">Provides <c>ArtifactsPath</c>, the root directory for stored uploads.</param>
    /// <returns>
    /// A map of form field name (as sent by the client) to the path of the stored file.
    /// Empty when the request carries no supported uploads.
    /// </returns>
    public static Dictionary<string,string> HandleFileUploads(this IRequest request, AppConfig appConfig)
    {
        var apiKeyId = (request.GetApiKey() as ApiKeysFeature.ApiKey)?.Id ?? 0;

        // "/" in a custom date format is the culture's date separator, so format with the
        // invariant culture to always get literal slashes in the directory layout.
        var datePart = DateTime.UtcNow.ToString("yyyy/MM/dd/", System.Globalization.CultureInfo.InvariantCulture);
        var uploadDir = "/input/" + datePart + apiKeyId + "/";

        var fileMap = new Dictionary<string, string>();
        foreach (var file in request.Files)
        {
            // Ordinal, case-insensitive match avoids culture-specific lowercasing surprises.
            if (!supportedUploadNames.Contains(file.Name, StringComparer.OrdinalIgnoreCase))
                continue;

            var ext = ResolveExtension(file.FileName, file.ContentType);

            // Read the whole upload from the start, even if something consumed the stream earlier.
            var input = file.InputStream;
            if (input.CanSeek)
                input.Position = 0;

            byte[] bytes;
            using (var buffer = new MemoryStream())
            {
                input.CopyTo(buffer);
                bytes = buffer.ToArray();
            }

            var sha256 = bytes.ComputeSha256();
            Directory.CreateDirectory(appConfig.ArtifactsPath.CombineWith(uploadDir));
            var uploadFile = appConfig.ArtifactsPath.CombineWith(uploadDir, sha256 + "." + ext);

            // WriteAllBytes truncates, so a stale or partially written file is fully replaced.
            File.WriteAllBytes(uploadFile, bytes);

            // Rewind so callers can still read the upload after it has been saved.
            if (input.CanSeek)
                input.Position = 0;

            fileMap[file.Name] = uploadFile;
        }
        return fileMap;
    }

    /// <summary>
    /// Picks the extension for a stored upload: the text after the last '.' of the client file
    /// name when it consists solely of letters and digits, otherwise one derived from the
    /// content type. The file name is client-controlled, so anything that could carry path
    /// separators is rejected rather than written into the storage path.
    /// </summary>
    private static string ResolveExtension(string? fileName, string contentType)
    {
        var dot = fileName?.LastIndexOf('.') ?? -1;
        if (dot >= 0)
        {
            var candidate = fileName![(dot + 1)..];
            if (candidate.Length > 0 && candidate.All(char.IsLetterOrDigit))
                return candidate;
        }
        return GetFileExtension(contentType);
    }

    /// <summary>
    /// Maps a content type to a file extension; anything unrecognised is stored as "webp".
    /// </summary>
    private static string GetFileExtension(string contentType)
    {
        return contentType switch
        {
            "audio/mpeg" => "mp3",
            "audio/x-wav" => "wav",
            "audio/wav" => "wav",
            _ => "webp"
        };
    }
}

AiServer.ServiceInterface/OpenAiChatServices.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,27 @@ public object Any(QueryAiTypes request)
3838
/// <summary>
/// Lists the distinct, sorted qualified model names offered by the configured AI providers,
/// optionally limited to one provider type and/or to models flagged as vision models.
/// </summary>
public object Any(ActiveAiModels request)
{
    // A provider is kept when no provider filter was sent or its AiType matches the filter.
    var qualifiedNames =
        from provider in appData.AiProviders
        where request.Provider == null || provider.AiType?.Provider == request.Provider
        from providerModel in provider.Models
        let qualified = appData.GetQualifiedModel(providerModel.Model)
        where qualified != null
        select qualified!;

    var results = qualifiedNames
        .Distinct()
        .OrderBy(name => name)
        .ToList();

    if (request.Vision == true)
    {
        // Vision capability is looked up by the model id, i.e. the part before any ':' tag.
        var visionIds = appData.AiModels
            .GetAll()
            .Where(model => model.Vision == true)
            .Select(model => model.Id)
            .ToSet();
        results = results.FindAll(name => visionIds.Contains(name.LeftPart(':')));
    }

    return new StringsResponse
    {
        Results = results
    };
}
5264

AiServer.ServiceModel/ApiAdmin.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ public class QueryAiTypes : QueryDb<AiType> {}
1313

1414
[Tag(Tags.AiInfo)]
1515
[Api("Active AI Worker Models available in AI Server")]
16-
public class ActiveAiModels : IGet, IReturn<StringsResponse> {}
16+
public class ActiveAiModels : IGet, IReturn<StringsResponse>
{
    // When set, only models served by providers of this type are returned.
    [ApiMember(Description = "Only return models from AI Providers of this type")]
    public AiProviderType? Provider { get; set; }

    // When true, only models flagged as vision models are returned.
    [ApiMember(Description = "Only return models that support Vision")]
    public bool? Vision { get; set; }
}
1721

1822

1923
[Tag(Tags.AiInfo)]

AiServer.ServiceModel/Generations.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,12 @@ public class ImageToText : IGeneration, IReturn<TextGenerationResponse>
210210
[Input(Type = "file")]
211211
public string? Image { get; set; }
212212

213+
// Name of the model, not a flag: the request is routed to the vision model only when both Model and Prompt are provided.
[ApiMember(Description = "Optional Vision Model to use for the request (requires Prompt)")]
public string? Model { get; set; }
215+
216+
[ApiMember(Description = "Prompt for the vision model")]
217+
public string? Prompt { get; set; }
218+
213219
[ApiMember(Description = "Optional client-provided identifier for the request")]
214220
public string? RefId { get; set; }
215221

AiServer/wwwroot/css/app.css

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@
107107
}
108108

109109
/*
110-
! tailwindcss v3.4.15 | MIT License | https://tailwindcss.com
110+
! tailwindcss v3.4.17 | MIT License | https://tailwindcss.com
111111
*/
112112

113113
/*
@@ -1022,6 +1022,10 @@ select{
10221022
grid-column: span 6 / span 6;
10231023
}
10241024

1025+
.col-span-4 {
1026+
grid-column: span 4 / span 4;
1027+
}
1028+
10251029
.row-span-2 {
10261030
grid-row: span 2 / span 2;
10271031
}
@@ -1449,6 +1453,10 @@ select{
14491453
max-height: 15rem;
14501454
}
14511455

1456+
.max-h-64 {
1457+
max-height: 16rem;
1458+
}
1459+
14521460
.max-h-\[25dvh\] {
14531461
max-height: 25dvh;
14541462
}

AiServer/wwwroot/mjs/components/Chat.mjs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { ref, computed, onMounted, inject, watch, nextTick } from "vue"
22
import { useClient } from "@servicestack/vue"
3-
import { marked } from "../markdown.mjs"
3+
import { renderMarkdown } from "../markdown.mjs"
44
import { addCopyButtonToCodeBlocks } from "../dom.mjs"
55
import { useUiLayout, UiLayout, ThreadStorage, HistoryTitle, HistoryGroups } from "../utils.mjs"
66
import { QueryPrompts, ActiveAiModels, OpenAiChatCompletion } from "../dtos.mjs"
@@ -331,16 +331,6 @@ export default {
331331
}
332332
}
333333

334-
function renderMarkdown(content) {
335-
if (content) {
336-
console.log(content)
337-
content = content
338-
.replaceAll(`\\[ \\boxed{`,'\n<span class="inline-block text-xl text-blue-500 bg-blue-50 px-3 py-1 rounded">')
339-
.replaceAll('} \\]','</span>\n')
340-
}
341-
return marked.parse(content)
342-
}
343-
344334
watch(() => routes.id, updated)
345335
watch(() => selectedPrompt.value, () => {
346336
prefs.value.prompt = selectedPrompt.value?.name

AiServer/wwwroot/mjs/components/ImageToText.mjs

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import { ref, onMounted, inject, watch } from "vue"
22
import { useClient, useFiles } from "@servicestack/vue"
33
import { createErrorStatus } from "@servicestack/client"
4-
import { ImageToText } from "../dtos.mjs"
4+
import {ActiveAiModels, ImageToText} from "../dtos.mjs"
55
import { UiLayout, ThreadStorage, HistoryTitle, HistoryGroups, useUiLayout, icons, Img, acceptedImages } from "../utils.mjs"
66
import FileUpload from "./FileUpload.mjs"
7+
import { renderMarkdown } from "../markdown.mjs"
78

89
export default {
910
components: {
@@ -38,6 +39,24 @@ export default {
3839
</div>
3940
</fieldset>
4041
</div>
42+
<div class="mt-2 grid grid-cols-6 gap-4">
43+
<div class="col-span-2">
44+
<Autocomplete id="model" :options="models" v-model="prefs.model" label="Vision Model"
45+
:match="(x, value) => x.toLowerCase().includes(value.toLowerCase())"
46+
placeholder="Select Vision Model..."
47+
:disabled="!!routes.id" :readonly="!!routes.id">
48+
<template #item="name">
49+
<div class="flex items-center">
50+
<Icon class="h-6 w-6 flex-shrink-0" :src="'/icons/models/' + name" loading="lazy" />
51+
<span class="ml-3 truncate">{{name}}</span>
52+
</div>
53+
</template>
54+
</Autocomplete>
55+
</div>
56+
<div class="col-span-4">
57+
<TextInput id="prompt" v-model="request.prompt" required placeholder="Prompt" />
58+
</div>
59+
</div>
4160
<div class="mt-4 mb-8 flex justify-center">
4261
<PrimaryButton :key="renderKey" type="submit" :disabled="!validPrompt()">
4362
<svg class="-ml-0.5 h-6 w-6 mr-1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path fill="currentColor" d="M11 16V7.85l-2.6 2.6L7 9l5-5l5 5l-1.4 1.45l-2.6-2.6V16zm-5 4q-.825 0-1.412-.587T4 18v-3h2v3h12v-3h2v3q0 .825-.587 1.413T18 20z"/></svg>
@@ -66,8 +85,8 @@ export default {
6685
</div>
6786
</div>
6887
<div>
69-
<div v-for="output in result.response.results" class="relative border border-indigo-600/25 rounded-lg p-2 mb-4 overflow-hidden">
70-
<div class="prose">{{output.text}}</div>
88+
<div v-for="output in result.response.results" class="relative border border-indigo-600/25 rounded-lg p-2 px-4 mb-4 overflow-hidden">
89+
<div v-html="renderMarkdown(output.text ?? '')" class="prose"></div>
7190
</div>
7291
</div>
7392
</div>
@@ -89,6 +108,7 @@ export default {
89108
const refUi = ref()
90109
const refForm = ref()
91110
const refImage = ref()
111+
const models = ref([])
92112
const ui = useUiLayout(refUi)
93113
const renderKey = ref(0)
94114
const { filePathUri, getExt, extSrc, svgToDataUri } = useFiles()
@@ -134,6 +154,10 @@ export default {
134154
let formData = new FormData(refForm.value)
135155
const image = formData.get('image').name
136156

157+
if (prefs.value.model) {
158+
request.value.model = prefs.value.model
159+
}
160+
137161
const api = await client.apiForm(request.value, formData, { jsconfig: 'eccn' })
138162
/** @type {ArtifactGenerationResponse} */
139163
const r = api.response
@@ -215,6 +239,8 @@ export default {
215239
if (thread.value) {
216240
Object.keys(storage.defaults).forEach(k =>
217241
request.value[k] = thread.value[k] ?? storage.defaults[k])
242+
prefs.value.model = thread.value.model ? thread.value.model : ''
243+
request.value.prompt = thread.value.prompt ? thread.value.prompt : ''
218244
}
219245
} else {
220246
thread.value = null
@@ -248,8 +274,20 @@ export default {
248274

249275

250276
watch(() => routes.id, updated)
277+
278+
// Keep the prompt in step with the selected vision model.
// Watching the scalar (not a freshly allocated array) means the callback runs only when
// the model actually changes. The default prompt is applied or cleared only when the user
// has not typed their own, so a prompt restored from a saved thread is not overwritten.
const DEFAULT_VISION_PROMPT = 'Describe this image'
watch(() => prefs.value.model, model => {
    const current = request.value.prompt ?? ''
    if (model) {
        if (!current) request.value.prompt = DEFAULT_VISION_PROMPT
    } else if (current === DEFAULT_VISION_PROMPT) {
        request.value.prompt = ''
    }
})
251283

252284
// Load the Ollama vision models for the model picker, then initialise the view.
onMounted(async () => {
    const api = await client.api(new ActiveAiModels({
        provider: 'OllamaAiProvider',
        vision: true,
    }))
    // api.response is undefined when the request fails; fall back to an empty list so the
    // page still initialises. Copy before sorting to avoid mutating the response object.
    models.value = [...(api.response?.results ?? [])].sort((a, b) => a.localeCompare(b))
    updated()
})
255293

@@ -259,7 +297,9 @@ export default {
259297
storage,
260298
routes,
261299
client,
300+
prefs,
262301
history,
302+
models,
263303
request,
264304
visibleFields,
265305
validPrompt,
@@ -276,6 +316,7 @@ export default {
276316
getThreadResults,
277317
saveHistoryItem,
278318
removeHistoryItem,
319+
renderMarkdown,
279320
acceptedImages,
280321
renderKey,
281322
}

0 commit comments

Comments
 (0)