Skip to content

Commit bcfa451

Browse files
authored
Merge pull request #1064 from unsync/feature/cloudflare-image-generate
feat: adds image generation capability to Workers AI API
2 parents 7fa8ce1 + 0380772 commit bcfa451

File tree

7 files changed

+133
-100
lines changed

7 files changed

+133
-100
lines changed

src/providers/workers-ai/api.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@ const WorkersAiAPIConfig: ProviderAPIConfig = {
1818
case 'chatComplete': {
1919
return `/${model}`;
2020
}
21-
case 'embed':
21+
case 'embed': {
2222
return `/${model}`;
23+
}
24+
case 'imageGenerate': {
25+
return `/${model}`;
26+
}
2327
default:
2428
return '';
2529
}

src/providers/workers-ai/chatComplete.ts

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ import {
55
ErrorResponse,
66
ProviderConfig,
77
} from '../types';
8+
import { generateInvalidProviderResponseError } from '../utils';
89
import {
9-
generateErrorResponse,
10-
generateInvalidProviderResponseError,
11-
} from '../utils';
10+
WorkersAiErrorResponse,
11+
WorkersAiErrorResponseTransform,
12+
} from './utils';
1213

1314
export const WorkersAiChatCompleteConfig: ProviderConfig = {
1415
messages: {
@@ -36,16 +37,6 @@ export const WorkersAiChatCompleteConfig: ProviderConfig = {
3637
},
3738
};
3839

39-
export interface WorkersAiErrorObject {
40-
code: string;
41-
message: string;
42-
}
43-
44-
interface WorkersAiErrorResponse {
45-
success: boolean;
46-
errors: WorkersAiErrorObject[];
47-
}
48-
4940
interface WorkersAiChatCompleteResponse {
5041
result: {
5142
response: string;
@@ -60,26 +51,6 @@ interface WorkersAiChatCompleteStreamResponse {
6051
p?: string;
6152
}
6253

63-
export const WorkersAiErrorResponseTransform: (
64-
response: WorkersAiErrorResponse
65-
) => ErrorResponse | undefined = (response) => {
66-
if ('errors' in response) {
67-
return generateErrorResponse(
68-
{
69-
message: response.errors
70-
?.map((error) => `Error ${error.code}:${error.message}`)
71-
.join(', '),
72-
type: null,
73-
param: null,
74-
code: null,
75-
},
76-
WORKERS_AI
77-
);
78-
}
79-
80-
return undefined;
81-
};
82-
8354
// TODO: cloudflare do not return the usage
8455
export const WorkersAiChatCompleteResponseTransform: (
8556
response: WorkersAiChatCompleteResponse | WorkersAiErrorResponse,

src/providers/workers-ai/complete.ts

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import { Params } from '../../types/requestBody';
22
import { CompletionResponse, ErrorResponse, ProviderConfig } from '../types';
33
import { WORKERS_AI } from '../../globals';
4+
import { generateInvalidProviderResponseError } from '../utils';
45
import {
5-
generateErrorResponse,
6-
generateInvalidProviderResponseError,
7-
} from '../utils';
6+
WorkersAiErrorResponse,
7+
WorkersAiErrorResponseTransform,
8+
} from './utils';
89

910
export const WorkersAiCompleteConfig: ProviderConfig = {
1011
prompt: {
@@ -24,16 +25,6 @@ export const WorkersAiCompleteConfig: ProviderConfig = {
2425
},
2526
};
2627

27-
export interface WorkersAiErrorObject {
28-
code: string;
29-
message: string;
30-
}
31-
32-
interface WorkersAiErrorResponse {
33-
success: boolean;
34-
errors: WorkersAiErrorObject[];
35-
}
36-
3728
interface WorkersAiCompleteResponse {
3829
result: {
3930
response: string;
@@ -48,26 +39,6 @@ interface WorkersAiCompleteStreamResponse {
4839
p?: string;
4940
}
5041

51-
export const WorkersAiErrorResponseTransform: (
52-
response: WorkersAiErrorResponse
53-
) => ErrorResponse | undefined = (response) => {
54-
if ('errors' in response) {
55-
return generateErrorResponse(
56-
{
57-
message: response.errors
58-
?.map((error) => `Error ${error.code}:${error.message}`)
59-
.join(', '),
60-
type: null,
61-
param: null,
62-
code: null,
63-
},
64-
WORKERS_AI
65-
);
66-
}
67-
68-
return undefined;
69-
};
70-
7142
export const WorkersAiCompleteResponseTransform: (
7243
response: WorkersAiCompleteResponse | WorkersAiErrorResponse,
7344
responseStatus: number,

src/providers/workers-ai/embed.ts

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import { WORKERS_AI } from '../../globals';
22
import { EmbedParams, EmbedResponse } from '../../types/embedRequestBody';
33
import { ErrorResponse, ProviderConfig } from '../types';
4+
import { generateInvalidProviderResponseError } from '../utils';
45
import {
5-
generateErrorResponse,
6-
generateInvalidProviderResponseError,
7-
} from '../utils';
6+
WorkersAiErrorResponse,
7+
WorkersAiErrorResponseTransform,
8+
} from './utils';
89

910
export const WorkersAiEmbedConfig: ProviderConfig = {
1011
input: {
@@ -20,36 +21,6 @@ export const WorkersAiEmbedConfig: ProviderConfig = {
2021
},
2122
};
2223

23-
export interface WorkersAiErrorObject {
24-
code: string;
25-
message: string;
26-
}
27-
28-
interface WorkersAiErrorResponse {
29-
success: boolean;
30-
errors: WorkersAiErrorObject[];
31-
}
32-
33-
export const WorkersAiErrorResponseTransform: (
34-
response: WorkersAiErrorResponse
35-
) => ErrorResponse | undefined = (response) => {
36-
if ('errors' in response) {
37-
return generateErrorResponse(
38-
{
39-
message: response.errors
40-
?.map((error) => `Error ${error.code}:${error.message}`)
41-
.join(', '),
42-
type: null,
43-
param: null,
44-
code: null,
45-
},
46-
WORKERS_AI
47-
);
48-
}
49-
50-
return undefined;
51-
};
52-
5324
/**
5425
* The structure of the CohereEmbedResponse.
5526
* @interface
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import { WORKERS_AI } from '../../globals';
2+
import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types';
3+
import { generateInvalidProviderResponseError } from '../utils';
4+
import {
5+
WorkersAiErrorResponse,
6+
WorkersAiErrorResponseTransform,
7+
} from './utils';
8+
9+
interface WorkersAiImageGenerateResponse extends ImageGenerateResponse {
10+
result: {
11+
image?: string;
12+
};
13+
success: boolean;
14+
errors: string[];
15+
messages: string[];
16+
}
17+
18+
export const WorkersAiImageGenerateConfig: ProviderConfig = {
19+
prompt: {
20+
param: 'prompt',
21+
required: true,
22+
},
23+
negative_prompt: {
24+
param: 'negative_prompt',
25+
},
26+
steps: [
27+
{
28+
param: 'num_steps',
29+
},
30+
{
31+
param: 'steps',
32+
},
33+
],
34+
size: [
35+
{
36+
param: 'height',
37+
transform: (params: any) =>
38+
parseInt(params.size.toLowerCase().split('x')[1]),
39+
},
40+
{
41+
param: 'width',
42+
transform: (params: any) =>
43+
parseInt(params.size.toLowerCase().split('x')[0]),
44+
},
45+
],
46+
seed: {
47+
param: 'seed',
48+
},
49+
};
50+
51+
export const WorkersAiImageGenerateResponseTransform: (
52+
response: WorkersAiImageGenerateResponse | WorkersAiErrorResponse,
53+
responseStatus: number
54+
) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
55+
if (responseStatus !== 200) {
56+
const errorResponse = WorkersAiErrorResponseTransform(
57+
response as WorkersAiErrorResponse
58+
);
59+
if (errorResponse) return errorResponse;
60+
}
61+
62+
// the imageGenerate response is not always the same for all Cloudflare image models
63+
// we currently only support the image model that returns a base64 image
64+
if ('result' in response && 'image' in response.result) {
65+
return {
66+
created: Math.floor(Date.now() / 1000),
67+
data: [
68+
{
69+
b64_json: response.result.image,
70+
},
71+
],
72+
provider: WORKERS_AI,
73+
};
74+
}
75+
76+
return generateInvalidProviderResponseError(response, WORKERS_AI);
77+
};

src/providers/workers-ai/index.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,24 @@ import {
1111
WorkersAiCompleteStreamChunkTransform,
1212
} from './complete';
1313
import { WorkersAiEmbedConfig, WorkersAiEmbedResponseTransform } from './embed';
14+
import {
15+
WorkersAiImageGenerateConfig,
16+
WorkersAiImageGenerateResponseTransform,
17+
} from './imageGenerate';
1418

1519
const WorkersAiConfig: ProviderConfigs = {
1620
complete: WorkersAiCompleteConfig,
1721
chatComplete: WorkersAiChatCompleteConfig,
1822
api: WorkersAiAPIConfig,
1923
embed: WorkersAiEmbedConfig,
24+
imageGenerate: WorkersAiImageGenerateConfig,
2025
responseTransforms: {
2126
'stream-complete': WorkersAiCompleteStreamChunkTransform,
2227
complete: WorkersAiCompleteResponseTransform,
2328
chatComplete: WorkersAiChatCompleteResponseTransform,
2429
'stream-chatComplete': WorkersAiChatCompleteStreamChunkTransform,
2530
embed: WorkersAiEmbedResponseTransform,
31+
imageGenerate: WorkersAiImageGenerateResponseTransform,
2632
},
2733
};
2834

src/providers/workers-ai/utils.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import { ErrorResponse } from '../types';
2+
import { generateErrorResponse } from '../utils';
3+
import { WORKERS_AI } from '../../globals';
4+
5+
export interface WorkersAiErrorResponse {
6+
success: boolean;
7+
errors: WorkersAiErrorObject[];
8+
}
9+
10+
export interface WorkersAiErrorObject {
11+
code: string;
12+
message: string;
13+
}
14+
15+
export const WorkersAiErrorResponseTransform: (
16+
response: WorkersAiErrorResponse
17+
) => ErrorResponse | undefined = (response) => {
18+
if ('errors' in response) {
19+
return generateErrorResponse(
20+
{
21+
message: response.errors
22+
?.map((error) => `Error ${error.code}:${error.message}`)
23+
.join(', '),
24+
type: null,
25+
param: null,
26+
code: null,
27+
},
28+
WORKERS_AI
29+
);
30+
}
31+
32+
return undefined;
33+
};

0 commit comments

Comments
 (0)