diff --git a/src/api.ts b/src/api.ts index b2e48b8..b409299 100644 --- a/src/api.ts +++ b/src/api.ts @@ -32,7 +32,7 @@ import { } from './observability/generation'; import { Step, StepType } from './observability/step'; import { CleanThreadFields, Thread } from './observability/thread'; -import { Prompt } from './prompt-engineering/prompt'; +import { IPromptRollout, Prompt } from './prompt-engineering/prompt'; import { Environment, Maybe, @@ -2080,4 +2080,78 @@ export class API { return new Prompt(this, promptData); } + + /** + * Retrieves a prompt A/B testing rollout by its name. + * + * @param name - The name of the prompt to retrieve. + * @returns A list of prompt rollout versions. + */ + public async getPromptAbTesting( + name: string + ): Promise { + const query = ` + query getPromptLineageRollout($projectId: String, $lineageName: String!) { + promptLineageRollout(projectId: $projectId, lineageName: $lineageName) { + pageInfo { + startCursor + endCursor + } + edges { + node { + version + rollout + } + } + } + } + `; + + const variables = { lineageName: name }; + const result = await this.makeGqlCall(query, variables); + + if (!result.data || !result.data.promptLineageRollout) { + return null; + } + + const response = result.data.promptLineageRollout; + + return response.edges.map((x: any) => x.node); + } + + /** + * Update a prompt A/B testing rollout by its name. + * + * @param name - The name of the prompt to retrieve. + * @param rollouts - A list of prompt rollout versions. + * @returns A list of prompt rollout versions. + */ + public async updatePromptAbTesting(name: string, rollouts: IPromptRollout[]) { + const mutation = ` + mutation updatePromptLineageRollout( + $projectId: String + $name: String! + $rollouts: [PromptVersionRolloutInput!]! + ) { + updatePromptLineageRollout( + projectId: $projectId + name: $name + rollouts: $rollouts + ) { + ok + message + errorCode + } + } + `; + + const variables = { name: name, rollouts }; + const result = await this.makeGqlCall(mutation, variables); + + if (!result.data || !result.data.updatePromptLineageRollout) { + return null; + } + + return result.data.promptLineageRollout; + } } diff --git a/src/prompt-engineering/prompt.ts b/src/prompt-engineering/prompt.ts index e786e41..599e9eb 100644 --- a/src/prompt-engineering/prompt.ts +++ b/src/prompt-engineering/prompt.ts @@ -46,6 +46,11 @@ class PromptFields extends Utils { variables!: IPromptVariableDefinition[]; } +export interface IPromptRollout { + version: number; + rollout: number; +} + export type PromptConstructor = OmitUtils; export class Prompt extends PromptFields { diff --git a/tests/api.test.ts b/tests/api.test.ts index ff78430..41a2a2c 100644 --- a/tests/api.test.ts +++ b/tests/api.test.ts @@ -600,7 +600,7 @@ describe('End to end tests for the SDK', function () { it('should format a prompt with default values', async () => { const prompt = await client.api.getPrompt('Default'); - const formatted = prompt!.format(); + const formatted = prompt!.formatMessages(); const expected = `Hello, this is a test value and this @@ -617,7 +617,7 @@ is a templated list.`; it('should format a prompt with custom values', async () => { const prompt = await client.api.getPrompt('Default'); - const formatted = prompt!.format({ test_var: 'Edited value' }); + const formatted = prompt!.formatMessages({ test_var: 'Edited value' }); const expected = `Hello, this is a Edited value and this @@ -630,5 +630,16 @@ is a templated list.`; expect(formatted.length).toBe(1); expect(formatted[0].content).toBe(expected); }); + + it('should get a prompt A/B testing configuration', async () => { + await client.api.updatePromptAbTesting('Default', [ + { version: 0, rollout: 100 } + ]); + const rollouts = await client.api.getPromptAbTesting('Default'); + expect(rollouts).not.toBeNull(); + expect(rollouts?.length).toBe(1); + expect(rollouts![0].rollout).toBe(100); + expect(rollouts![0].version).toBe(0); + }); }); });