@@ -4,6 +4,7 @@ import { Message } from '../types/chat';
4
4
import { StreamControlHandler } from './streaming-control' ;
5
5
import { SETTINGS_CHANGE_EVENT , SettingsService } from './settings-service' ;
6
6
import { MCPService } from './mcp-service' ;
7
+ import { AIServiceCapability } from '../types/capabilities' ;
7
8
8
9
export interface ModelOption {
9
10
id : string ;
@@ -35,7 +36,7 @@ export class AIService {
35
36
private state : AIState = {
36
37
status : 'idle' ,
37
38
error : null ,
38
- isCachingModels : false
39
+ isCachingModels : false ,
39
40
} ;
40
41
private listeners : Set < ( ) => void > = new Set ( ) ;
41
42
private modelCache : Map < string , ModelOption [ ] > = new Map ( ) ;
@@ -68,11 +69,17 @@ export class AIService {
68
69
for ( const providerID of Object . keys ( settings . providers ) ) {
69
70
const providerSettings = settings . providers [ providerID ] ;
70
71
71
- if ( this . providers . has ( providerID ) ) {
72
+ if ( this . providers . has ( providerID ) ) {
72
73
this . providers . delete ( providerID ) ;
73
- this . providers . set ( providerID , ProviderFactory . getNewProvider ( providerID ) ) ;
74
- }
75
- else if ( providerSettings && providerSettings . apiKey && providerSettings . apiKey . length > 0 ) {
74
+ this . providers . set (
75
+ providerID ,
76
+ ProviderFactory . getNewProvider ( providerID )
77
+ ) ;
78
+ } else if (
79
+ providerSettings &&
80
+ providerSettings . apiKey &&
81
+ providerSettings . apiKey . length > 0
82
+ ) {
76
83
const providerInstance = ProviderFactory . getNewProvider ( providerID ) ;
77
84
if ( providerInstance ) {
78
85
this . providers . set ( providerID , providerInstance ) ;
@@ -90,7 +97,7 @@ export class AIService {
90
97
// Refresh models when settings change
91
98
this . refreshModels ( ) ;
92
99
} ;
93
-
100
+
94
101
window . addEventListener ( SETTINGS_CHANGE_EVENT , handleSettingsChange ) ;
95
102
}
96
103
@@ -108,7 +115,7 @@ export class AIService {
108
115
* Notify all listeners of state changes
109
116
*/
110
117
private notifyListeners ( ) : void {
111
- this . listeners . forEach ( listener => listener ( ) ) ;
118
+ this . listeners . forEach ( ( listener ) => listener ( ) ) ;
112
119
}
113
120
114
121
/**
@@ -125,7 +132,7 @@ export class AIService {
125
132
private handleSuccess ( ) : void {
126
133
this . setState ( {
127
134
status : 'success' ,
128
- error : null
135
+ error : null ,
129
136
} ) ;
130
137
}
131
138
@@ -136,7 +143,7 @@ export class AIService {
136
143
console . error ( 'AI request error:' , error ) ;
137
144
this . setState ( {
138
145
status : 'error' ,
139
- error
146
+ error,
140
147
} ) ;
141
148
}
142
149
@@ -154,14 +161,14 @@ export class AIService {
154
161
if ( this . providers . has ( name ) ) {
155
162
return this . providers . get ( name ) ;
156
163
}
157
-
164
+
158
165
// If provider not in cache, try to create it
159
166
const provider = ProviderFactory . getNewProvider ( name ) ;
160
167
if ( provider ) {
161
168
this . providers . set ( name , provider ) ;
162
169
return provider ;
163
170
}
164
-
171
+
165
172
return undefined ;
166
173
}
167
174
@@ -172,11 +179,26 @@ export class AIService {
172
179
return Array . from ( this . providers . values ( ) ) ;
173
180
}
174
181
182
+ /**
183
+ * Get all providers that support image generation
184
+ */
185
+ public getImageGenerationProviders ( ) : AiServiceProvider [ ] {
186
+ const providers = this . getAllProviders ( ) ;
187
+ return providers . filter ( ( provider ) => {
188
+ // Check if the provider has any models with image generation capability
189
+ const models = provider . availableModels || [ ] ;
190
+ return models . some ( ( model ) => {
191
+ const capabilities = provider . getModelCapabilities ( model . modelId ) ;
192
+ return capabilities . includes ( AIServiceCapability . ImageGeneration ) ;
193
+ } ) ;
194
+ } ) ;
195
+ }
196
+
175
197
/**
176
198
* Get a streaming chat completion from the AI
177
199
*/
178
200
public async getChatCompletion (
179
- messages : Message [ ] ,
201
+ messages : Message [ ] ,
180
202
options : CompletionOptions ,
181
203
streamController : StreamControlHandler
182
204
) : Promise < Message | null > {
@@ -185,18 +207,25 @@ export class AIService {
185
207
const providerName = options . provider ;
186
208
const modelName = options . model ;
187
209
const useStreaming = options . stream ;
188
-
210
+
189
211
// Get provider instance
190
212
const provider = this . getProvider ( providerName ) ;
191
213
192
- console . log ( 'Provider: ' , providerName , ' Model: ' , modelName , ' Use streaming: ' , useStreaming ) ;
193
-
214
+ console . log (
215
+ 'Provider: ' ,
216
+ providerName ,
217
+ ' Model: ' ,
218
+ modelName ,
219
+ ' Use streaming: ' ,
220
+ useStreaming
221
+ ) ;
222
+
194
223
if ( ! provider ) {
195
224
throw new Error ( `Provider ${ providerName } not available` ) ;
196
225
}
197
-
226
+
198
227
const result = await provider . getChatCompletion (
199
- messages ,
228
+ messages ,
200
229
{
201
230
model : modelName ,
202
231
provider : providerName ,
@@ -212,7 +241,6 @@ export class AIService {
212
241
} ,
213
242
streamController
214
243
) ;
215
-
216
244
217
245
return result ;
218
246
} catch ( e ) {
@@ -221,8 +249,11 @@ export class AIService {
221
249
this . handleSuccess ( ) ;
222
250
return null ;
223
251
}
224
-
225
- const error = e instanceof Error ? e : new Error ( 'Unknown error during streaming chat completion' ) ;
252
+
253
+ const error =
254
+ e instanceof Error
255
+ ? e
256
+ : new Error ( 'Unknown error during streaming chat completion' ) ;
226
257
this . handleError ( error ) ;
227
258
return null ;
228
259
}
@@ -244,20 +275,20 @@ export class AIService {
244
275
throw new Error ( 'Not implemented' ) ;
245
276
246
277
// this.startRequest();
247
-
278
+
248
279
// try {
249
280
// const provider = this.getImageGenerationProvider();
250
-
281
+
251
282
// if (!provider) {
252
283
// throw new Error('No image generation provider available');
253
284
// }
254
-
285
+
255
286
// if (!provider.generateImage) {
256
287
// throw new Error(`Provider ${provider.name} does not support image generation`);
257
288
// }
258
-
289
+
259
290
// const result = await provider.generateImage(prompt, options);
260
-
291
+
261
292
// this.handleSuccess();
262
293
// return result;
263
294
// } catch (e) {
@@ -317,62 +348,67 @@ export class AIService {
317
348
const cacheKey = 'all_providers' ;
318
349
const cachedTime = this . lastFetchTime . get ( cacheKey ) || 0 ;
319
350
const now = Date . now ( ) ;
320
-
351
+
321
352
// Return cached models if they're still valid
322
353
if ( this . modelCache . has ( cacheKey ) && now - cachedTime < this . CACHE_TTL ) {
323
354
return this . modelCache . get ( cacheKey ) || [ ] ;
324
355
}
325
-
356
+
326
357
// Otherwise, collect models from all providers
327
358
const allModels : ModelOption [ ] = [ ] ;
328
359
const providerPromises = [ ] ;
329
-
360
+
330
361
for ( const provider of this . getAllProviders ( ) ) {
331
362
providerPromises . push ( this . getModelsForProvider ( provider . id ) ) ;
332
363
}
333
-
364
+
334
365
const results = await Promise . all ( providerPromises ) ;
335
-
366
+
336
367
// Flatten results and filter out duplicates
337
- results . forEach ( models => {
368
+ results . forEach ( ( models ) => {
338
369
allModels . push ( ...models ) ;
339
370
} ) ;
340
-
371
+
341
372
// Cache and return results
342
373
this . modelCache . set ( cacheKey , allModels ) ;
343
374
this . lastFetchTime . set ( cacheKey , now ) ;
344
-
375
+
345
376
return allModels ;
346
377
}
347
378
348
379
/**
349
380
* Get models for a specific provider
350
381
*/
351
- public async getModelsForProvider ( providerName : string ) : Promise < ModelOption [ ] > {
382
+ public async getModelsForProvider (
383
+ providerName : string
384
+ ) : Promise < ModelOption [ ] > {
352
385
// Check if we already have a cached result
353
386
const cachedTime = this . lastFetchTime . get ( providerName ) || 0 ;
354
387
const now = Date . now ( ) ;
355
-
388
+
356
389
// Return cached models if they're still valid
357
- if ( this . modelCache . has ( providerName ) && now - cachedTime < this . CACHE_TTL ) {
390
+ if (
391
+ this . modelCache . has ( providerName ) &&
392
+ now - cachedTime < this . CACHE_TTL
393
+ ) {
358
394
return this . modelCache . get ( providerName ) || [ ] ;
359
395
}
360
-
396
+
361
397
// Get provider instance
362
398
const provider = this . getProvider ( providerName ) ;
363
399
if ( ! provider ) {
364
400
console . warn ( `Provider ${ providerName } not available` ) ;
365
401
return [ ] ;
366
402
}
367
-
403
+
368
404
this . setState ( { isCachingModels : true } ) ;
369
-
405
+
370
406
try {
371
407
// Fetch models from provider
372
408
const models = await provider . fetchAvailableModels ( ) ;
373
-
409
+
374
410
// Convert to ModelOption format
375
- const modelOptions : ModelOption [ ] = models . map ( model => ( {
411
+ const modelOptions : ModelOption [ ] = models . map ( ( model ) => ( {
376
412
id : model . modelId ,
377
413
name : model . modelName ,
378
414
provider : providerName ,
@@ -381,7 +417,7 @@ export class AIService {
381
417
// Cache results
382
418
this . modelCache . set ( providerName , modelOptions ) ;
383
419
this . lastFetchTime . set ( providerName , now ) ;
384
-
420
+
385
421
this . setState ( { isCachingModels : false } ) ;
386
422
return modelOptions ;
387
423
} catch ( error ) {
@@ -398,7 +434,7 @@ export class AIService {
398
434
// Clear cache
399
435
this . modelCache . clear ( ) ;
400
436
this . lastFetchTime . clear ( ) ;
401
-
437
+
402
438
this . refreshProviders ( ) ;
403
439
404
440
// Re-fetch all models
0 commit comments