Skip to content

Commit 9c74b4a

Browse files
committed
Fix: Add custom image generation capabilities edit
1 parent e4a0bdc commit 9c74b4a

12 files changed

+198
-90
lines changed

src/components/pages/ImageGenerationPage.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import { ImageGenerationManager, ImageGenerationStatus, ImageGenerationHandler }
1212
import { DatabaseIntegrationService } from "../../services/database-integration";
1313
import { ImageGenerationResult } from "../../types/image";
1414
import ImageGenerateHistoryItem from "../image/ImageGenerateHistoryItem";
15+
import { AiServiceProvider } from "../../types/ai-service";
16+
import { AiServiceCapability } from "../../types/ai-service";
1517

1618
export const ImageGenerationPage = () => {
1719
const { t } = useTranslation();

src/services/ai-service.ts

Lines changed: 79 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { Message } from '../types/chat';
44
import { StreamControlHandler } from './streaming-control';
55
import { SETTINGS_CHANGE_EVENT, SettingsService } from './settings-service';
66
import { MCPService } from './mcp-service';
7+
import { AIServiceCapability } from '../types/capabilities';
78

89
export interface ModelOption {
910
id: string;
@@ -35,7 +36,7 @@ export class AIService {
3536
private state: AIState = {
3637
status: 'idle',
3738
error: null,
38-
isCachingModels: false
39+
isCachingModels: false,
3940
};
4041
private listeners: Set<() => void> = new Set();
4142
private modelCache: Map<string, ModelOption[]> = new Map();
@@ -68,11 +69,17 @@ export class AIService {
6869
for (const providerID of Object.keys(settings.providers)) {
6970
const providerSettings = settings.providers[providerID];
7071

71-
if(this.providers.has(providerID)) {
72+
if (this.providers.has(providerID)) {
7273
this.providers.delete(providerID);
73-
this.providers.set(providerID, ProviderFactory.getNewProvider(providerID));
74-
}
75-
else if (providerSettings && providerSettings.apiKey && providerSettings.apiKey.length > 0) {
74+
this.providers.set(
75+
providerID,
76+
ProviderFactory.getNewProvider(providerID)
77+
);
78+
} else if (
79+
providerSettings &&
80+
providerSettings.apiKey &&
81+
providerSettings.apiKey.length > 0
82+
) {
7683
const providerInstance = ProviderFactory.getNewProvider(providerID);
7784
if (providerInstance) {
7885
this.providers.set(providerID, providerInstance);
@@ -90,7 +97,7 @@ export class AIService {
9097
// Refresh models when settings change
9198
this.refreshModels();
9299
};
93-
100+
94101
window.addEventListener(SETTINGS_CHANGE_EVENT, handleSettingsChange);
95102
}
96103

@@ -108,7 +115,7 @@ export class AIService {
108115
* Notify all listeners of state changes
109116
*/
110117
private notifyListeners(): void {
111-
this.listeners.forEach(listener => listener());
118+
this.listeners.forEach((listener) => listener());
112119
}
113120

114121
/**
@@ -125,7 +132,7 @@ export class AIService {
125132
private handleSuccess(): void {
126133
this.setState({
127134
status: 'success',
128-
error: null
135+
error: null,
129136
});
130137
}
131138

@@ -136,7 +143,7 @@ export class AIService {
136143
console.error('AI request error:', error);
137144
this.setState({
138145
status: 'error',
139-
error
146+
error,
140147
});
141148
}
142149

@@ -154,14 +161,14 @@ export class AIService {
154161
if (this.providers.has(name)) {
155162
return this.providers.get(name);
156163
}
157-
164+
158165
// If provider not in cache, try to create it
159166
const provider = ProviderFactory.getNewProvider(name);
160167
if (provider) {
161168
this.providers.set(name, provider);
162169
return provider;
163170
}
164-
171+
165172
return undefined;
166173
}
167174

@@ -172,11 +179,26 @@ export class AIService {
172179
return Array.from(this.providers.values());
173180
}
174181

182+
/**
183+
* Get all providers that support image generation
184+
*/
185+
public getImageGenerationProviders(): AiServiceProvider[] {
186+
const providers = this.getAllProviders();
187+
return providers.filter((provider) => {
188+
// Check if the provider has any models with image generation capability
189+
const models = provider.availableModels || [];
190+
return models.some((model) => {
191+
const capabilities = provider.getModelCapabilities(model.modelId);
192+
return capabilities.includes(AIServiceCapability.ImageGeneration);
193+
});
194+
});
195+
}
196+
175197
/**
176198
* Get a streaming chat completion from the AI
177199
*/
178200
public async getChatCompletion(
179-
messages: Message[],
201+
messages: Message[],
180202
options: CompletionOptions,
181203
streamController: StreamControlHandler
182204
): Promise<Message | null> {
@@ -185,18 +207,25 @@ export class AIService {
185207
const providerName = options.provider;
186208
const modelName = options.model;
187209
const useStreaming = options.stream;
188-
210+
189211
// Get provider instance
190212
const provider = this.getProvider(providerName);
191213

192-
console.log('Provider: ', providerName, ' Model: ', modelName, ' Use streaming: ', useStreaming);
193-
214+
console.log(
215+
'Provider: ',
216+
providerName,
217+
' Model: ',
218+
modelName,
219+
' Use streaming: ',
220+
useStreaming
221+
);
222+
194223
if (!provider) {
195224
throw new Error(`Provider ${providerName} not available`);
196225
}
197-
226+
198227
const result = await provider.getChatCompletion(
199-
messages,
228+
messages,
200229
{
201230
model: modelName,
202231
provider: providerName,
@@ -212,7 +241,6 @@ export class AIService {
212241
},
213242
streamController
214243
);
215-
216244

217245
return result;
218246
} catch (e) {
@@ -221,8 +249,11 @@ export class AIService {
221249
this.handleSuccess();
222250
return null;
223251
}
224-
225-
const error = e instanceof Error ? e : new Error('Unknown error during streaming chat completion');
252+
253+
const error =
254+
e instanceof Error
255+
? e
256+
: new Error('Unknown error during streaming chat completion');
226257
this.handleError(error);
227258
return null;
228259
}
@@ -244,20 +275,20 @@ export class AIService {
244275
throw new Error('Not implemented');
245276

246277
// this.startRequest();
247-
278+
248279
// try {
249280
// const provider = this.getImageGenerationProvider();
250-
281+
251282
// if (!provider) {
252283
// throw new Error('No image generation provider available');
253284
// }
254-
285+
255286
// if (!provider.generateImage) {
256287
// throw new Error(`Provider ${provider.name} does not support image generation`);
257288
// }
258-
289+
259290
// const result = await provider.generateImage(prompt, options);
260-
291+
261292
// this.handleSuccess();
262293
// return result;
263294
// } catch (e) {
@@ -317,62 +348,67 @@ export class AIService {
317348
const cacheKey = 'all_providers';
318349
const cachedTime = this.lastFetchTime.get(cacheKey) || 0;
319350
const now = Date.now();
320-
351+
321352
// Return cached models if they're still valid
322353
if (this.modelCache.has(cacheKey) && now - cachedTime < this.CACHE_TTL) {
323354
return this.modelCache.get(cacheKey) || [];
324355
}
325-
356+
326357
// Otherwise, collect models from all providers
327358
const allModels: ModelOption[] = [];
328359
const providerPromises = [];
329-
360+
330361
for (const provider of this.getAllProviders()) {
331362
providerPromises.push(this.getModelsForProvider(provider.id));
332363
}
333-
364+
334365
const results = await Promise.all(providerPromises);
335-
366+
336367
// Flatten results and filter out duplicates
337-
results.forEach(models => {
368+
results.forEach((models) => {
338369
allModels.push(...models);
339370
});
340-
371+
341372
// Cache and return results
342373
this.modelCache.set(cacheKey, allModels);
343374
this.lastFetchTime.set(cacheKey, now);
344-
375+
345376
return allModels;
346377
}
347378

348379
/**
349380
* Get models for a specific provider
350381
*/
351-
public async getModelsForProvider(providerName: string): Promise<ModelOption[]> {
382+
public async getModelsForProvider(
383+
providerName: string
384+
): Promise<ModelOption[]> {
352385
// Check if we already have a cached result
353386
const cachedTime = this.lastFetchTime.get(providerName) || 0;
354387
const now = Date.now();
355-
388+
356389
// Return cached models if they're still valid
357-
if (this.modelCache.has(providerName) && now - cachedTime < this.CACHE_TTL) {
390+
if (
391+
this.modelCache.has(providerName) &&
392+
now - cachedTime < this.CACHE_TTL
393+
) {
358394
return this.modelCache.get(providerName) || [];
359395
}
360-
396+
361397
// Get provider instance
362398
const provider = this.getProvider(providerName);
363399
if (!provider) {
364400
console.warn(`Provider ${providerName} not available`);
365401
return [];
366402
}
367-
403+
368404
this.setState({ isCachingModels: true });
369-
405+
370406
try {
371407
// Fetch models from provider
372408
const models = await provider.fetchAvailableModels();
373-
409+
374410
// Convert to ModelOption format
375-
const modelOptions: ModelOption[] = models.map(model => ({
411+
const modelOptions: ModelOption[] = models.map((model) => ({
376412
id: model.modelId,
377413
name: model.modelName,
378414
provider: providerName,
@@ -381,7 +417,7 @@ export class AIService {
381417
// Cache results
382418
this.modelCache.set(providerName, modelOptions);
383419
this.lastFetchTime.set(providerName, now);
384-
420+
385421
this.setState({ isCachingModels: false });
386422
return modelOptions;
387423
} catch (error) {
@@ -398,7 +434,7 @@ export class AIService {
398434
// Clear cache
399435
this.modelCache.clear();
400436
this.lastFetchTime.clear();
401-
437+
402438
this.refreshProviders();
403439

404440
// Re-fetch all models

src/services/core/ai-service-provider.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ export interface AiServiceProvider {
5151
/**
5252
* Get the capabilities of a model with this provider
5353
*/
54-
getModelCapabilities(model: string): AIServiceCapability[];
54+
getModelCapabilities(modelId: string): AIServiceCapability[];
5555

5656
/**
5757
* Fetch available models from the provider API

src/services/providers/anthropic-service.ts

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,18 @@ export class AnthropicService implements AiServiceProvider {
7676
/**
7777
* Get the capabilities of a model with this provider
7878
*/
79-
// eslint-disable-next-line @typescript-eslint/no-unused-vars
80-
getModelCapabilities(model: string): AIServiceCapability[] {
79+
getModelCapabilities(modelId: string): AIServiceCapability[] {
80+
// Get model data by modelId
81+
const models = this.settingsService.getModels(this.name);
82+
const modelData = models.find(x => x.modelId === modelId);
83+
let hasImageGeneration = false;
84+
85+
if(modelData?.modelCapabilities.findIndex(x => x === AIServiceCapability.ImageGeneration) !== -1){
86+
hasImageGeneration = true;
87+
}
88+
8189
return mapModelCapabilities(
82-
false,
90+
hasImageGeneration,
8391
false,
8492
false,
8593
false,

src/services/providers/common-provider-service.ts

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,18 @@ export class CommonProviderHelper implements AiServiceProvider {
104104
/**
105105
* Get the capabilities of a model with this provider
106106
*/
107-
// eslint-disable-next-line @typescript-eslint/no-unused-vars
108-
getModelCapabilities(model: string): AIServiceCapability[] {
107+
getModelCapabilities(modelId: string): AIServiceCapability[] {
108+
// Get model data by modelId
109+
const models = this.settingsService.getModels(this.providerID);
110+
const modelData = models.find(x => x.modelId === modelId);
111+
let hasImageGeneration = false;
112+
113+
if(modelData?.modelCapabilities.findIndex(x => x === AIServiceCapability.ImageGeneration) !== -1){
114+
hasImageGeneration = true;
115+
}
116+
109117
return mapModelCapabilities(
110-
false,
118+
hasImageGeneration,
111119
false,
112120
false,
113121
false,

0 commit comments

Comments
 (0)