Skip to content

WIP: feat: create very simple wrappers #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 10, 2024
Merged
21 changes: 21 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@ import { API } from './api';
import instrumentation from './instrumentation';
import openai from './openai';
import { Step, StepConstructor, Thread, ThreadConstructor } from './types';
import {
StepWrapperOptions,
ThreadWrapperOptions,
wrapInStep,
wrapInThread
} from './wrappers';

export * from './types';
export * from './generation';
export type * from './wrappers';

export type * from './instrumentation';

Expand Down Expand Up @@ -39,4 +46,18 @@ export class LiteralClient {
const runData = { ...data, type: 'run' as const };
return new Step(this.api, runData);
}

wrapInStep<TArgs extends unknown[], TReturn>(
fn: (...args: TArgs) => Promise<TReturn>,
options: StepWrapperOptions
) {
return wrapInStep(this, fn, options);
}

wrapInThread<TArgs extends unknown[], TReturn>(
fn: (...args: TArgs) => Promise<TReturn>,
options: ThreadWrapperOptions<TArgs>
) {
return wrapInThread(this, fn, options);
}
}
70 changes: 70 additions & 0 deletions src/wrappers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { AsyncLocalStorage } from 'node:async_hooks';

import type {
LiteralClient,
Step,
StepConstructor,
ThreadConstructor
} from './index';

const storage = new AsyncLocalStorage<Step>();

export type StepWrapperOptions = {
step: StepConstructor;
};

export const wrapInStep =
<TArgs extends unknown[], TReturn>(
client: LiteralClient,
fn: (...args: TArgs) => Promise<TReturn>,
options: StepWrapperOptions
) =>
async (...args: TArgs): Promise<TReturn> => {
const parentStep = storage.getStore();
const step = parentStep
? parentStep.step(options.step)
: client.step({ ...options.step, type: 'run' });

const startTime = new Date();
const result = await storage.run(step, () => fn(...args));

step.input = { inputs: args };
step.output = { output: result };
step.startTime = startTime.toISOString();
step.endTime = new Date().toISOString();
await step.send();

return result;
};

export type ThreadWrapperOptions<TArgs extends unknown[]> = {
thread: Omit<ThreadConstructor, 'id'> & {
id?: (...args: TArgs) => string;
};
run: StepConstructor;
};

export const wrapInThread =
<TArgs extends unknown[], TReturn>(
client: LiteralClient,
fn: (...args: TArgs) => Promise<TReturn>,
options: ThreadWrapperOptions<TArgs>
) =>
async (...args: TArgs): Promise<TReturn> => {
const { id, ...threadOptions } = options.thread;
const thread = await client
.thread({ ...threadOptions, id: id?.(...args) })
.upsert();
const runStep = thread.step(options.run);

const startTime = new Date();
const result = await storage.run(runStep, () => fn(...args));

runStep.input = { inputs: args };
runStep.output = { output: result };
runStep.startTime = startTime.toISOString();
runStep.endTime = new Date().toISOString();
await runStep.send();

return result;
};
72 changes: 72 additions & 0 deletions tests/integration/wrappers.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import 'dotenv/config';

import { LiteralClient } from '../../src';

const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));

describe('Wrapper', () => {
let client: LiteralClient;

beforeAll(function () {
const url = process.env.LITERAL_API_URL;
const apiKey = process.env.LITERAL_API_KEY;

if (!url || !apiKey) {
throw new Error('Missing environment variables');
}

client = new LiteralClient(apiKey, url);
});

it.only('should handle simple use case', async () => {
const retrieve = client.wrapInStep(
async (_query: string) => {
await sleep(1000);
return [
{ score: 0.8, text: 'France is a country in Europe' },
{ score: 0.7, text: 'Paris is the capital of France' }
];
},
{
step: {
name: 'Retrieve',
type: 'retrieval'
}
}
);

const completion = client.wrapInStep(
async (_query: string, _augmentations: string[]) => {
await sleep(1000);
return { content: 'Paris is a city in Europe' };
},
{
step: {
name: 'Completion',
type: 'llm'
}
}
);

const main = client.wrapInThread(
async (query: string) => {
const results = await retrieve(query);
const augmentations = results.map((result) => result.text);
const completionText = await completion(query, augmentations);
return completionText.content;
},
{
thread: {
name: 'Test Wrappers'
},
run: {
name: 'Run',
type: 'run'
}
}
);

const result = await main('France');
expect(result).toBe('Paris is a city in Europe');
});
});
Loading