diff --git a/package.json b/package.json index b30acad..397d2e1 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@literalai/client", - "version": "0.0.602", + "version": "0.1.0", "description": "", "exports": { ".": { diff --git a/src/api.ts b/src/api.ts index a6d64d3..7d510ab 100644 --- a/src/api.ts +++ b/src/api.ts @@ -1828,6 +1828,65 @@ export class API { return Object.values(result.data).map((x: any) => new DatasetItem(x)); } + /** + * Creates a prompt variation for an experiment. + * This variation is not an official version until manually saved. + * + * @param name The name of the prompt to retrieve or create. + * @param templateMessages A list of template messages for the prompt. + * @param settings Optional settings for the prompt. + * @param tools Optional tools for the prompt. + * @returns The prompt variant id to link with the experiment. + */ + public async createPromptVariant( + name: string, + templateMessages: IGenerationMessage[], + settings?: Maybe>, + tools?: Maybe> + ): Promise { + const mutation = `mutation createPromptExperiment( + $fromLineageId: String + $fromVersion: Int + $scoreTemplateId: String + $templateMessages: Json + $settings: Json + $tools: Json + $variables: Json + ) { + createPromptExperiment( + fromLineageId: $fromLineageId + fromVersion: $fromVersion + scoreTemplateId: $scoreTemplateId + templateMessages: $templateMessages + settings: $settings + tools: $tools + variables: $variables + ) { + id + fromLineageId + fromVersion + scoreTemplateId + projectId + projectUserId + tools + settings + variables + templateMessages + } + } + `; + + const lineage = await this.getPromptLineageByName(name); + const result = await this.makeGqlCall(mutation, { + fromLineageId: lineage?.id, + templateMessages, + settings, + tools + }); + + return result.data.createPromptExperiment?.id; + } + /** * Creates a new dataset experiment. * @param datasetExperiment @@ -1840,12 +1899,12 @@ export class API { public async createExperiment(datasetExperiment: { name: string; datasetId?: string; - promptId?: string; + promptVariantId?: string; params?: Record | Array>; }) { const query = ` - mutation CreateDatasetExperiment($name: String!, $datasetId: String $promptId: String, $params: Json) { - createDatasetExperiment(name: $name, datasetId: $datasetId, promptId: $promptId, params: $params) { + mutation CreateDatasetExperiment($name: String!, $datasetId: String, $promptExperimentId: String, $params: Json) { + createDatasetExperiment(name: $name, datasetId: $datasetId, promptExperimentId: $promptExperimentId, params: $params) { id } } @@ -1853,7 +1912,7 @@ export class API { const datasetExperimentInput = { name: datasetExperiment.name, datasetId: datasetExperiment.datasetId, - promptId: datasetExperiment.promptId, + promptExperimentId: datasetExperiment.promptVariantId, params: datasetExperiment.params }; const result = await this.makeGqlCall(query, datasetExperimentInput); @@ -1947,6 +2006,34 @@ export class API { return result.data.createPromptLineage; } + /** + * Get an existing prompt lineage by name. + * + * @param name - The name of the prompt lineage. This parameter is required. + * @returns The existing prompt lineage object, or null. + */ + public async getPromptLineageByName(name: string) { + const query = `query promptLineage( + $name: String! + ) { + promptLineage( + name: $name + ) { + id + } + }`; + + const result = await this.makeGqlCall(query, { + name + }); + + if (!result.data || !result.data.promptLineage) { + return null; + } + + return result.data.promptLineage; + } + /** * @deprecated Please use getOrCreatePrompt instead. */ diff --git a/src/evaluation/dataset.ts b/src/evaluation/dataset.ts index 3dfcf96..1fa6edc 100644 --- a/src/evaluation/dataset.ts +++ b/src/evaluation/dataset.ts @@ -96,13 +96,13 @@ export class Dataset extends DatasetFields { */ async createExperiment(experiment: { name: string; - promptId?: string; + promptVariantId?: string; params?: Record | Array>; }) { const datasetExperiment = await this.api.createExperiment({ name: experiment.name, datasetId: this.id, - promptId: experiment.promptId, + promptVariantId: experiment.promptVariantId, params: experiment.params }); return new DatasetExperiment(this.api, datasetExperiment); diff --git a/tests/api.test.ts b/tests/api.test.ts index 3017d07..a23ac2d 100644 --- a/tests/api.test.ts +++ b/tests/api.test.ts @@ -448,6 +448,20 @@ describe('End to end tests for the SDK', function () { expect(experiment.id).not.toBeNull(); dataset.delete(); }); + + it('should create a dataset experiment with a prompt variant', async () => { + const promptVariantId = await client.api.createPromptVariant( + 'Default', + [{ role: 'user', content: 'hello' }], + { temperature: 0.5 } + ); + const experiment = await client.api.createExperiment({ + name: `test_${uuidv4()}`, + promptVariantId: promptVariantId + }); + expect(promptVariantId).toBeDefined(); + expect(experiment.id).not.toBeNull(); + }); }); describe('dataset item api', () => {