Skip to content

WIP: feat: create very simple wrappers #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import FormData from 'form-data';
import { createReadStream } from 'fs';
import { v4 as uuidv4 } from 'uuid';

import { LiteralClient } from '.';
import {
GenerationsFilter,
GenerationsOrderBy,
Expand Down Expand Up @@ -327,6 +328,8 @@ function addGenerationsToDatasetQueryBuilder(generationIds: string[]) {
}

export class API {
/** @ignore */
private client: LiteralClient;
/** @ignore */
private apiKey: string;
/** @ignore */
Expand All @@ -339,7 +342,13 @@ export class API {
public disabled: boolean;

/** @ignore */
constructor(apiKey: string, url: string, disabled?: boolean) {
constructor(
client: LiteralClient,
apiKey: string,
url: string,
disabled?: boolean
) {
this.client = client;
this.apiKey = apiKey;
this.url = url;
this.graphqlEndpoint = `${url}/api/graphql`;
Expand Down Expand Up @@ -509,7 +518,9 @@ export class API {

const response = result.data.steps;

response.data = response.edges.map((x: any) => new Step(this, x.node));
response.data = response.edges.map(
(x: any) => new Step(this.client, x.node, true)
);
delete response.edges;

return response;
Expand Down Expand Up @@ -541,7 +552,7 @@ export class API {
return null;
}

return new Step(this, result.data.step);
return new Step(this.client, result.data.step, true);
}

/**
Expand Down Expand Up @@ -878,7 +889,7 @@ export class API {
};

const response = await this.makeGqlCall(query, variables);
return new Thread(this, response.data.upsertThread);
return new Thread(this.client, response.data.upsertThread);
}

/**
Expand Down Expand Up @@ -943,7 +954,9 @@ export class API {

const response = result.data.threads;

response.data = response.edges.map((x: any) => new Thread(this, x.node));
response.data = response.edges.map(
(x: any) => new Thread(this.client, x.node)
);
delete response.edges;

return response;
Expand All @@ -967,12 +980,11 @@ export class API {
const variables = { id };

const response = await this.makeGqlCall(query, variables);

if (!response.data.threadDetail) {
return null;
}

return new Thread(this, response.data.threadDetail);
return new Thread(this.client, response.data.threadDetail);
}

/**
Expand Down
35 changes: 29 additions & 6 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
import { AsyncLocalStorage } from 'node:async_hooks';

import { API } from './api';
import instrumentation from './instrumentation';
import openai from './openai';
import { Step, StepConstructor, Thread, ThreadConstructor } from './types';
import {
Maybe,
Step,
StepConstructor,
Thread,
ThreadConstructor
} from './types';

export * from './types';
export * from './generation';

export type * from './instrumentation';

type StoredContext = {
currentThread: Thread | null;
currentStep: Step | null;
};

const storage = new AsyncLocalStorage<StoredContext>();

export class LiteralClient {
api: API;
openai: ReturnType<typeof openai>;
instrumentation: ReturnType<typeof instrumentation>;
store: AsyncLocalStorage<StoredContext> = storage;

constructor(apiKey?: string, apiUrl?: string, disabled?: boolean) {
if (!apiKey) {
Expand All @@ -22,21 +38,28 @@ export class LiteralClient {
apiUrl = process.env.LITERAL_API_URL || 'https://cloud.getliteral.ai';
}

this.api = new API(apiKey!, apiUrl!, disabled);
this.api = new API(this, apiKey!, apiUrl!, disabled);
this.openai = openai(this);
this.instrumentation = instrumentation(this);
}

thread(data?: ThreadConstructor) {
return new Thread(this.api, data);
return new Thread(this, data);
}

step(data: StepConstructor) {
return new Step(this.api, data);
return new Step(this, data);
}

run(data: Omit<StepConstructor, 'type'>) {
const runData = { ...data, type: 'run' as const };
return new Step(this.api, runData);
return this.step({ ...data, type: 'run' });
}

getCurrentThread(): Maybe<Thread> {
return storage.getStore()?.currentThread ?? null;
}

getCurrentStep(): Maybe<Step> {
return storage.getStore()?.currentStep ?? null;
}
}
106 changes: 99 additions & 7 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
} from 'openai/resources';
import { v4 as uuidv4 } from 'uuid';

import { LiteralClient } from '.';
import { API } from './api';
import { Generation, GenerationType, IGenerationMessage } from './generation';
import { CustomChatPromptTemplate } from './instrumentation/langchain';
Expand Down Expand Up @@ -37,7 +38,7 @@ export class Utils {
serialize(): any {
const dict: any = {};
Object.keys(this as any).forEach((key) => {
if (key === 'api') {
if (['api', 'client'].includes(key)) {
return;
}
if ((this as any)[key] !== undefined) {
Expand Down Expand Up @@ -135,20 +136,24 @@ export type ThreadConstructor = Omit<CleanThreadFields, 'id'> &
*/
export class Thread extends ThreadFields {
api: API;
client: LiteralClient;

/**
* Constructs a new Thread instance.
* @param api - The API instance to interact with backend services.
* @param data - Optional initial data for the thread, with an auto-generated ID if not provided.
*/
constructor(api: API, data?: ThreadConstructor) {
constructor(client: LiteralClient, data?: ThreadConstructor) {
super();
this.api = api;
this.api = client.api;
this.client = client;

if (!data) {
data = { id: uuidv4() };
} else if (!data.id) {
data.id = uuidv4();
}

Object.assign(this, data);
}

Expand All @@ -158,12 +163,21 @@ export class Thread extends ThreadFields {
* @returns A new Step instance linked to this thread.
*/
step(data: Omit<StepConstructor, 'threadId'>) {
return new Step(this.api, {
return new Step(this.client, {
...data,
threadId: this.id
});
}

/**
* Creates a new Run step associated with this thread.
* @param data - The data for the new step, excluding the thread ID and the type
* @returns A new Step instance linked to this thread.
*/
run(data: Omit<StepConstructor, 'threadId' | 'type'>) {
return this.step({ ...data, type: 'run' });
}

/**
* Upserts the thread data to the backend, creating or updating as necessary.
* @returns The updated Thread instance.
Expand All @@ -182,6 +196,31 @@ export class Thread extends ThreadFields {
});
return this;
}

async wrap<Output>(
cb: (thread: Thread) => Output | Promise<Output>,
updateThread?:
| ThreadConstructor
| ((output: Output) => ThreadConstructor)
| ((output: Output) => Promise<ThreadConstructor>)
) {
const output = await this.client.store.run(
{ currentThread: this, currentStep: null },
() => cb(this)
);

if (updateThread) {
const updatedThread =
typeof updateThread === 'function'
? await updateThread(output)
: updateThread;
Object.assign(this, updatedThread);
}

await this.upsert();

return output;
}
}

export type StepType =
Expand Down Expand Up @@ -222,22 +261,43 @@ export type StepConstructor = OmitUtils<StepFields>;
*/
export class Step extends StepFields {
api: API;
client: LiteralClient;

/**
* Constructs a new Step instance.
* @param api The API instance to be used for sending and managing steps.
* @param data The initial data for the step, excluding utility properties.
*/
constructor(api: API, data: StepConstructor) {
constructor(
client: LiteralClient,
data: StepConstructor,
ignoreContext?: true
) {
super();
this.api = api;
this.api = client.api;
this.client = client;

Object.assign(this, data);

// Automatically generate an ID if not provided.
if (!this.id) {
this.id = uuidv4();
}

if (ignoreContext) {
return;
}

// Automatically assign parent thread & step if there are any in the store.
const store = this.client.store.getStore();

if (store?.currentThread) {
this.threadId = store.currentThread.id;
}
if (store?.currentStep) {
this.parentId = store.currentStep.id;
}

// Set the creation and start time to the current time if not provided.
if (!this.createdAt) {
this.createdAt = new Date().toISOString();
Expand Down Expand Up @@ -284,7 +344,7 @@ export class Step extends StepFields {
* @returns A new Step instance.
*/
step(data: Omit<StepConstructor, 'threadId'>) {
return new Step(this.api, {
return new Step(this.client, {
...data,
threadId: this.threadId,
parentId: this.id
Expand All @@ -306,6 +366,38 @@ export class Step extends StepFields {
await this.api.sendSteps([this]);
return this;
}

async wrap<Output>(
cb: (step: Step) => Output | Promise<Output>,
updateStep?:
| Partial<StepConstructor>
| ((output: Output) => Partial<StepConstructor>)
| ((output: Output) => Promise<Partial<StepConstructor>>)
) {
const startTime = new Date();
this.startTime = startTime.toISOString();
const currentStore = this.client.store.getStore();

const output = await this.client.store.run(
{ currentThread: currentStore?.currentThread ?? null, currentStep: this },
() => cb(this)
);

this.output = { output };
this.endTime = new Date().toISOString();

if (updateStep) {
const updatedStep =
typeof updateStep === 'function'
? await updateStep(output)
: updateStep;
Object.assign(this, updatedStep);
}

await this.send();

return output;
}
}

/**
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/api.test.ts → tests/api.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
Dataset,
LiteralClient,
Score
} from '../../src';
} from '../src';

describe('End to end tests for the SDK', function () {
let client: LiteralClient;
Expand Down
12 changes: 12 additions & 0 deletions tests/async.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { AsyncLocalStorage } from 'node:async_hooks';

const storage = new AsyncLocalStorage<string>();

describe('Async Local Storage', () => {
it('is supported on this environment', () => {
storage.run('This is good', async () => {
const store = await storage.getStore();
expect(store).toEqual('This is good');
});
});
});
File renamed without changes
Loading
Loading