From 1105de879903e28b22944bc9131ab25f147c02c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Wed, 7 Aug 2024 17:51:14 +0200 Subject: [PATCH 1/2] feat: update sdk --- src/api.ts | 8 ++++++++ src/evaluation/experiment-item-run.ts | 7 ++++++- src/index.ts | 28 +++++++++++++++++++++++++++ src/observability/step.ts | 12 ++++++++++-- src/observability/thread.ts | 3 ++- tests/wrappers.test.ts | 27 ++++++++++++++++++++++++++ 6 files changed, 81 insertions(+), 4 deletions(-) diff --git a/src/api.ts b/src/api.ts index 5ca43de..b2e48b8 100644 --- a/src/api.ts +++ b/src/api.ts @@ -49,6 +49,7 @@ const version = packageJson.version; const stepFields = ` id threadId + rootRunId parentId startTime endTime @@ -152,6 +153,7 @@ function ingestStepsFieldsBuilder(steps: Step[]) { for (let id = 0; id < steps.length; id++) { generated += `$id_${id}: String! $threadId_${id}: String + $rootRunId_${id}: String $type_${id}: StepType $startTime_${id}: DateTime $endTime_${id}: DateTime @@ -177,6 +179,7 @@ function ingestStepsArgsBuilder(steps: Step[]) { step${id}: ingestStep( id: $id_${id} threadId: $threadId_${id} + rootRunId: $rootRunId_${id} startTime: $startTime_${id} endTime: $endTime_${id} type: $type_${id} @@ -1651,6 +1654,7 @@ export class API { input expectedOutput intermediarySteps + stepId } } `; @@ -1676,6 +1680,7 @@ export class API { input expectedOutput intermediarySteps + stepId } } `; @@ -1701,6 +1706,7 @@ export class API { input expectedOutput intermediarySteps + stepId } } `; @@ -1732,6 +1738,7 @@ export class API { input expectedOutput intermediarySteps + stepId } } `; @@ -1767,6 +1774,7 @@ export class API { input expectedOutput intermediarySteps + stepId } } `; diff --git a/src/evaluation/experiment-item-run.ts b/src/evaluation/experiment-item-run.ts index e7fbbfd..f71ab0f 100644 --- a/src/evaluation/experiment-item-run.ts +++ b/src/evaluation/experiment-item-run.ts @@ -39,7 +39,12 @@ export class ExperimentItemRun extends Step { { currentThread: currentStore?.currentThread ?? null, currentStep: this, - currentExperimentItemRunId: this.id ?? null + currentExperimentItemRunId: this.id ?? null, + rootRun: currentStore?.rootRun + ? currentStore?.rootRun + : this.type === 'run' + ? this + : null }, async () => { try { diff --git a/src/index.ts b/src/index.ts index b4ae30a..d1f8430 100644 --- a/src/index.ts +++ b/src/index.ts @@ -17,6 +17,7 @@ type StoredContext = { currentThread: Thread | null; currentStep: Step | null; currentExperimentItemRunId?: string | null; + rootRun: Step | null; }; const storage = new AsyncLocalStorage(); @@ -125,6 +126,16 @@ export class LiteralClient { return store?.currentExperimentItemRunId || null; } + /** + * Returns the root run from the context or null if none. + * @returns The root run, if any. + */ + _rootRun(): Step | null { + const store = storage.getStore(); + + return store?.rootRun || null; + } + /** * Gets the current thread from the context. * WARNING : this will throw if run outside of a thread context. @@ -175,4 +186,21 @@ export class LiteralClient { return store?.currentExperimentItemRunId; } + + /** + * Gets the root run from the context. + * WARNING : this will throw if run outside of a step context. + * @returns The current step, if any. + */ + getRootRun(): Step { + const store = storage.getStore(); + + if (!store?.rootRun) { + throw new Error( + 'Literal AI SDK : tried to access root run outside of a context.' + ); + } + + return store.rootRun; + } } diff --git a/src/observability/step.ts b/src/observability/step.ts index 4dc2d95..ff5e4f8 100644 --- a/src/observability/step.ts +++ b/src/observability/step.ts @@ -23,6 +23,7 @@ class StepFields extends Utils { name!: string; type!: StepType; threadId?: string; + rootRunId?: Maybe; createdAt?: Maybe; startTime?: Maybe; id?: Maybe; @@ -73,9 +74,10 @@ export class Step extends StepFields { return; } - // Automatically assign parent thread & step if there are any in the store. + // Automatically assign parent thread & step & rootRun if there are any in the store. this.threadId = this.threadId ?? this.client._currentThread()?.id; this.parentId = this.parentId ?? this.client._currentStep()?.id; + this.rootRunId = this.rootRunId ?? this.client._rootRun()?.id; // Set the creation and start time to the current time if not provided. if (!this.createdAt) { @@ -167,7 +169,12 @@ export class Step extends StepFields { currentThread: currentStore?.currentThread ?? null, currentExperimentItemRunId: currentStore?.currentExperimentItemRunId ?? null, - currentStep: this + currentStep: this, + rootRun: currentStore?.rootRun + ? currentStore?.rootRun + : this.type === 'run' + ? this + : null }, () => cb(this) ); @@ -197,6 +204,7 @@ export class Step extends StepFields { this.scores = updatedStep.scores ?? this.scores; this.attachments = updatedStep.attachments ?? this.attachments; this.environment = updatedStep.environment ?? this.environment; + this.rootRunId = updatedStep.rootRunId ?? this.rootRunId; } this.send().catch(console.error); diff --git a/src/observability/thread.ts b/src/observability/thread.ts index 51b37bd..459481c 100644 --- a/src/observability/thread.ts +++ b/src/observability/thread.ts @@ -109,7 +109,8 @@ export class Thread extends ThreadFields { currentThread: this, currentExperimentItemRunId: currentStore?.currentExperimentItemRunId ?? null, - currentStep: null + currentStep: null, + rootRun: null }, () => cb(this) ); diff --git a/tests/wrappers.test.ts b/tests/wrappers.test.ts index 7a961df..2666b6d 100644 --- a/tests/wrappers.test.ts +++ b/tests/wrappers.test.ts @@ -143,6 +143,32 @@ describe('Wrapper', () => { }); }); + it('handles nested runs', async () => { + let runId: Maybe; + let stepId: Maybe; + + const step = async (_query: string) => + client.step({ name: 'foo', type: 'undefined' }).wrap(async () => { + stepId = client.getCurrentStep()!.id; + }); + + await client.thread({ name: 'Test Wrappers Thread' }).wrap(async () => { + return client.run({ name: 'Test Wrappers Run' }).wrap(async () => { + runId = client.getCurrentStep()!.id; + + return client.run({ name: 'Test Nested Run' }).wrap(async () => { + await step('foo'); + }); + }); + }); + + await sleep(1000); + const run = await client.api.getStep(runId!); + const retrieveStep = await client.api.getStep(stepId!); + + expect(retrieveStep!.rootRunId).toEqual(run!.id); + }); + it('handles steps outside of a thread', async () => { let runId: Maybe; let stepId: Maybe; @@ -172,6 +198,7 @@ describe('Wrapper', () => { expect(step!.name).toEqual('Test Wrappers Step'); expect(step!.threadId).toBeNull(); expect(step!.parentId).toEqual(run!.id); + expect(step!.rootRunId).toEqual(run!.id); }); it("doesn't leak the current store when getting entities from the API", async () => { From f73840ac153aac45c62175beef6975f4dab92b66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Wed, 7 Aug 2024 17:59:11 +0200 Subject: [PATCH 2/2] refactor: rename value --- tests/wrappers.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/wrappers.test.ts b/tests/wrappers.test.ts index 2666b6d..a578f3c 100644 --- a/tests/wrappers.test.ts +++ b/tests/wrappers.test.ts @@ -164,9 +164,9 @@ describe('Wrapper', () => { await sleep(1000); const run = await client.api.getStep(runId!); - const retrieveStep = await client.api.getStep(stepId!); + const createdStep = await client.api.getStep(stepId!); - expect(retrieveStep!.rootRunId).toEqual(run!.id); + expect(createdStep!.rootRunId).toEqual(run!.id); }); it('handles steps outside of a thread', async () => {