Skip to content

Commit 01bb14a

Browse files
committed
Fix: Custom Provider Image Generation
1 parent 9c74b4a commit 01bb14a

11 files changed

+87
-40
lines changed

src/components/pages/ImageGenerationPage.tsx

Lines changed: 72 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,11 @@ import { ChevronDown, RefreshCw, Settings, Zap } from "lucide-react";
77
import { useTranslation } from "react-i18next";
88
import { AIService } from "../../services/ai-service";
99
import { OPENAI_PROVIDER_NAME } from "../../services/providers/openai-service";
10-
import { FORGE_PROVIDER_NAME as TENSORBLOCK_PROVIDER_NAME } from "../../services/providers/forge-service";
1110
import { ImageGenerationManager, ImageGenerationStatus, ImageGenerationHandler } from "../../services/image-generation-handler";
1211
import { DatabaseIntegrationService } from "../../services/database-integration";
1312
import { ImageGenerationResult } from "../../types/image";
1413
import ImageGenerateHistoryItem from "../image/ImageGenerateHistoryItem";
15-
import { AiServiceProvider } from "../../types/ai-service";
16-
import { AiServiceCapability } from "../../types/ai-service";
14+
1715

1816
export const ImageGenerationPage = () => {
1917
const { t } = useTranslation();
@@ -33,6 +31,8 @@ export const ImageGenerationPage = () => {
3331
const [isLoadingHistory, setIsLoadingHistory] = useState(true);
3432
const [isSettingsOpen, setIsSettingsOpen] = useState(false);
3533
const [selectedProvider, setSelectedProvider] = useState(OPENAI_PROVIDER_NAME);
34+
const [selectedModel, setSelectedModel] = useState("dall-e-3");
35+
const [availableProviders, setAvailableProviders] = useState<{id: string, name: string}[]>([]);
3636

3737
const settingsButtonRef = useRef<HTMLButtonElement>(null);
3838
const settingsPopupRef = useRef<HTMLDivElement>(null);
@@ -62,6 +62,31 @@ export const ImageGenerationPage = () => {
6262
}
6363
}, []);
6464

65+
// Load available image generation providers
66+
const loadImageGenerationProviders = useCallback(async () => {
67+
const aiService = AIService.getInstance();
68+
const providers = aiService.getImageGenerationProviders();
69+
70+
const providerOptions = providers.map(provider => ({
71+
id: provider.id,
72+
name: provider.name || provider.id
73+
}));
74+
75+
setAvailableProviders(providerOptions);
76+
77+
// Set default provider if none is selected or current one isn't available
78+
if (!selectedProvider || !providerOptions.some(p => p.id === selectedProvider)) {
79+
if (providerOptions.length > 0) {
80+
setSelectedProvider(providerOptions[0].id);
81+
}
82+
}
83+
}, [selectedProvider]);
84+
85+
const handleGetProviderNameById = (id: string) => {
86+
const provider = availableProviders.find(p => p.id === id);
87+
return provider ? provider.name : id;
88+
}
89+
6590
// Initialize image generation manager and load settings
6691
useEffect(() => {
6792
const initialize = async () => {
@@ -88,6 +113,12 @@ export const ImageGenerationPage = () => {
88113
if (settings.imageGenerationProvider) {
89114
setSelectedProvider(settings.imageGenerationProvider);
90115
}
116+
if (settings.imageGenerationModel) {
117+
setSelectedModel(settings.imageGenerationModel);
118+
}
119+
120+
// Load available providers
121+
await loadImageGenerationProviders();
91122

92123
// Load image generation history from database
93124
await refreshImageHistory();
@@ -100,7 +131,19 @@ export const ImageGenerationPage = () => {
100131
};
101132

102133
initialize();
103-
}, [refreshImageHistory]);
134+
}, [refreshImageHistory, loadImageGenerationProviders]);
135+
136+
// Listen for settings changes
137+
useEffect(() => {
138+
const handleSettingsChange = () => {
139+
loadImageGenerationProviders();
140+
};
141+
142+
window.addEventListener(SETTINGS_CHANGE_EVENT, handleSettingsChange);
143+
return () => {
144+
window.removeEventListener(SETTINGS_CHANGE_EVENT, handleSettingsChange);
145+
};
146+
}, [loadImageGenerationProviders]);
104147

105148
// Check if API key is available
106149
useEffect(() => {
@@ -183,14 +226,9 @@ export const ImageGenerationPage = () => {
183226
setError(null);
184227

185228
try {
186-
let providerService;
187-
188229
// Get the appropriate service based on selected provider
189-
if (selectedProvider === TENSORBLOCK_PROVIDER_NAME) {
190-
providerService = AIService.getInstance().getProvider(TENSORBLOCK_PROVIDER_NAME);
191-
} else {
192-
providerService = AIService.getInstance().getProvider(OPENAI_PROVIDER_NAME);
193-
}
230+
const aiService = AIService.getInstance();
231+
const providerService = aiService.getProvider(selectedProvider);
194232

195233
if (!providerService) {
196234
throw new Error(`${selectedProvider} service not available`);
@@ -201,10 +239,10 @@ export const ImageGenerationPage = () => {
201239
const handler = imageManager.createHandler({
202240
prompt: prompt,
203241
seed: randomSeed,
204-
number: imageCount,
242+
number: 1,
205243
aspectRatio: aspectRatio,
206244
provider: selectedProvider,
207-
model: "dall-e-3",
245+
model: selectedModel,
208246
});
209247

210248
// Set status to generating
@@ -480,7 +518,7 @@ export const ImageGenerationPage = () => {
480518
</div>
481519

482520
<div className="mb-4">
483-
<label className="flex items-center block mb-2 text-sm font-medium text-gray-700">
521+
<label className="flex items-center mb-2 text-sm font-medium text-gray-700">
484522
{t("imageGeneration.randomSeed")}
485523
<div
486524
className="flex items-center justify-center w-4 h-4 ml-1 text-xs text-gray-500 bg-gray-200 rounded-full cursor-help"
@@ -507,15 +545,15 @@ export const ImageGenerationPage = () => {
507545
</div>
508546

509547
<div className="mb-4">
510-
<label className="flex items-center block mb-2 text-sm font-medium text-gray-700">
548+
<label className="flex items-center mb-2 text-sm font-medium text-gray-700">
511549
{t("imageGeneration.provider")}
512550
</label>
513551
<div className="relative">
514552
<button
515553
className="flex items-center justify-between w-full p-3 text-left provider-dropdown-toggle input-box"
516554
onClick={toggleProviderDropdown}
517555
>
518-
<span>{selectedProvider}</span>
556+
<span>{handleGetProviderNameById(selectedProvider)}</span>
519557
<ChevronDown size={18} className="text-gray-500" />
520558
</button>
521559

@@ -525,30 +563,31 @@ export const ImageGenerationPage = () => {
525563
className="absolute z-20 w-full mt-1 bg-white border border-gray-300 rounded-md shadow-lg"
526564
>
527565
<ul className="py-1">
528-
<li
529-
className={`px-3 py-2 cursor-pointer hover:bg-gray-100 ${
530-
selectedProvider === OPENAI_PROVIDER_NAME ? 'bg-gray-50 font-medium' : ''
531-
}`}
532-
onClick={() => handleProviderSelect(OPENAI_PROVIDER_NAME)}
533-
>
534-
OpenAI
535-
</li>
536-
<li
537-
className={`px-3 py-2 cursor-pointer hover:bg-gray-100 ${
538-
selectedProvider === TENSORBLOCK_PROVIDER_NAME ? 'bg-gray-50 font-medium' : ''
539-
}`}
540-
onClick={() => handleProviderSelect(TENSORBLOCK_PROVIDER_NAME)}
541-
>
542-
TensorBlock
543-
</li>
566+
{availableProviders.length > 0 ? (
567+
availableProviders.map(provider => (
568+
<li
569+
key={provider.id}
570+
className={`px-3 py-2 cursor-pointer hover:bg-gray-100 ${
571+
selectedProvider === provider.id ? 'bg-gray-50 font-medium' : ''
572+
}`}
573+
onClick={() => handleProviderSelect(provider.id)}
574+
>
575+
{provider.name}
576+
</li>
577+
))
578+
) : (
579+
<li className="px-3 py-2 text-gray-500">
580+
{t("chat.noImageProvidersAvailable")}
581+
</li>
582+
)}
544583
</ul>
545584
</div>
546585
)}
547586
</div>
548587
</div>
549588

550589
<div className="mb-4">
551-
<label className="flex items-center block mb-2 text-sm font-medium text-gray-700">
590+
<label className="flex items-center mb-2 text-sm font-medium text-gray-700">
552591
{t("imageGeneration.model")}
553592
</label>
554593
<div className="relative">

src/services/providers/anthropic-service.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ export class AnthropicService implements AiServiceProvider {
2929
this.settingsService = SettingsService.getInstance();
3030
const providerSettings = this.settingsService.getProviderSettings(ANTHROPIC_PROVIDER_NAME);
3131

32+
this.apiModels = this.settingsService.getModels(ANTHROPIC_PROVIDER_NAME);
33+
3234
this._apiKey = providerSettings.apiKey || '';
3335

3436
this.anthropic = new Anthropic({

src/services/providers/common-provider-service.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ export class CommonProviderHelper implements AiServiceProvider {
5555

5656
this._apiKey = providerSettings.apiKey || '';
5757

58+
this.apiModels = this.settingsService.getModels(providerName);
59+
5860
this.ProviderInstance = this.createClient();
5961
}
6062

@@ -106,7 +108,7 @@ export class CommonProviderHelper implements AiServiceProvider {
106108
*/
107109
getModelCapabilities(modelId: string): AIServiceCapability[] {
108110
// Get model data by modelId
109-
const models = this.settingsService.getModels(this.providerID);
111+
const models = this.settingsService.getModels(this.providerName);
110112
const modelData = models.find(x => x.modelId === modelId);
111113
let hasImageGeneration = false;
112114

src/services/providers/custom-service.ts

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ export class CustomService implements AiServiceProvider {
3838
this.baseURL = `${baseURL}/${this.apiVersion}`;
3939
this.apiKey = apiKey;
4040

41+
this.apiModels = this.settingsService.getModels(this.providerID);
42+
4143
this.openAIProvider = createOpenAI({
4244
apiKey: this.apiKey,
4345
compatibility: 'compatible',
@@ -51,9 +53,6 @@ export class CustomService implements AiServiceProvider {
5153
*/
5254
get name(): string {
5355
const providerSettings = this.settingsService.getProviderSettings(this.providerID);
54-
const error = new Error('Custom provider settings: ' + JSON.stringify(providerSettings));
55-
console.log(error);
56-
console.log('Provider Name: ', providerSettings.providerName);
5756
return providerSettings.providerName;
5857
}
5958

@@ -88,8 +87,7 @@ export class CustomService implements AiServiceProvider {
8887
// eslint-disable-next-line @typescript-eslint/no-unused-vars
8988
getModelCapabilities(modelId: string): AIServiceCapability[] {
9089
// Get model data by modelId
91-
const models = this.settingsService.getModels(this.providerID);
92-
const modelData = models.find(x => x.modelId === modelId);
90+
const modelData = this.apiModels.find(x => x.modelId === modelId);
9391
let hasImageGeneration = false;
9492

9593
if(modelData?.modelCapabilities.findIndex(x => x === AIServiceCapability.ImageGeneration) !== -1){

src/services/providers/fireworks-service.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ export class FireworksService implements AiServiceProvider {
2424
*/
2525
constructor() {
2626
this.settingsService = SettingsService.getInstance();
27+
this.apiModels = this.settingsService.getModels(FIREWORKS_PROVIDER_NAME);
2728
this.commonProviderHelper = new CommonProviderHelper(FIREWORKS_PROVIDER_NAME, this.createClient);
2829
}
2930

src/services/providers/forge-service.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ export class ForgeService implements AiServiceProvider {
2424
*/
2525
constructor() {
2626
this.settingsService = SettingsService.getInstance();
27+
this.apiModels = this.settingsService.getModels(FORGE_PROVIDER_NAME);
2728
this.commonProviderHelper = new CommonProviderHelper(FORGE_PROVIDER_NAME, this.createClient);
2829
}
2930

src/services/providers/gemini-service.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ export class GeminiService implements AiServiceProvider {
2626
*/
2727
constructor() {
2828
this.settingsService = SettingsService.getInstance();
29+
this.apiModels = this.settingsService.getModels(GEMINI_PROVIDER_NAME);
2930
const providerSettings = this.settingsService.getProviderSettings(GEMINI_PROVIDER_NAME);
3031

3132
this._apiKey = providerSettings.apiKey || '';

src/services/providers/openai-service.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ export class OpenAIService implements AiServiceProvider {
2424
*/
2525
constructor() {
2626
this.settingsService = SettingsService.getInstance();
27+
this.apiModels = this.settingsService.getModels(OPENAI_PROVIDER_NAME);
2728
this.commonProviderHelper = new CommonProviderHelper(OPENAI_PROVIDER_NAME, this.createClient);
2829
}
2930

src/services/providers/openrouter-service.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ export class OpenRouterService implements AiServiceProvider {
2525
*/
2626
constructor() {
2727
this.settingsService = SettingsService.getInstance();
28+
this.apiModels = this.settingsService.getModels(OPENROUTER_PROVIDER_NAME);
2829
const providerSettings = this.settingsService.getProviderSettings(OPENROUTER_PROVIDER_NAME);
2930

3031
this._apiKey = providerSettings.apiKey || '';

src/services/providers/together-service.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ export class TogetherService implements AiServiceProvider {
2424
*/
2525
constructor() {
2626
this.settingsService = SettingsService.getInstance();
27+
this.apiModels = this.settingsService.getModels(TOGETHER_PROVIDER_NAME);
2728
this.commonProviderHelper = new CommonProviderHelper(TOGETHER_PROVIDER_NAME, this.createClient);
2829
}
2930

src/types/capabilities.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ export const mapModelCapabilities = (
3737
];
3838

3939
if (supportsImages) {
40-
capabilities.push(AIServiceCapability.VisionAnalysis);
40+
capabilities.push(AIServiceCapability.ImageGeneration);
4141
}
4242

4343
if (supportsAudio) {

0 commit comments

Comments
 (0)