feat(instrumentation): make the openai instrumentation context aware #42

Merged · 3 commits · Jul 15, 2024

108 changes: 50 additions & 58 deletions src/api.ts
@@ -399,6 +399,7 @@ export class API {

return response.data;
} catch (e) {
console.error(e);
if (e instanceof AxiosError) {
throw new Error(JSON.stringify(e.response?.data.errors));
} else {
@@ -426,6 +427,7 @@ export class API {

return response.data;
} catch (e) {
console.error(e);
if (e instanceof AxiosError) {
throw new Error(JSON.stringify(e.response?.data));
} else {
@@ -696,70 +698,60 @@ export class API {
orderBy?: GenerationsOrderBy;
}): Promise<PaginatedResponse<PersistedGeneration>> {
const query = `
query GetGenerations(
$after: ID,
$before: ID,
$cursorAnchor: DateTime,
$filters: [generationsInputType!],
$orderBy: GenerationsOrderByInput,
$first: Int,
$last: Int,
$projectId: String,
query GetGenerations(
$after: ID
$before: ID
$cursorAnchor: DateTime
$filters: [generationsInputType!]
$orderBy: GenerationsOrderByInput
$first: Int
$last: Int
$projectId: String
) {
generations(
after: $after,
before: $before,
cursorAnchor: $cursorAnchor,
filters: $filters,
orderBy: $orderBy,
first: $first,
last: $last,
projectId: $projectId,
) {
generations(
after: $after
before: $before
cursorAnchor: $cursorAnchor
filters: $filters
orderBy: $orderBy
first: $first
last: $last
projectId: $projectId
) {
pageInfo {
startCursor
endCursor
hasNextPage
hasPreviousPage
startCursor
endCursor
hasNextPage
hasPreviousPage
}
totalCount
edges {
cursor
node {
id
projectId
prompt
completion
createdAt
provider
model
variables
messages
messageCompletion
tools
settings
stepId
tokenCount
duration
inputTokenCount
outputTokenCount
ttFirstToken
duration
tokenThroughputInSeconds
error
type
tags
step {
threadId
thread {
participant {
identifier
}
}
}
}
}
cursor
node {
id
projectId
prompt
completion
createdAt
provider
model
variables
messages
messageCompletion
tools
settings
tokenCount
duration
inputTokenCount
outputTokenCount
ttFirstToken
tokenThroughputInSeconds
error
type
tags
}
}
}
}`;

const result = await this.makeGqlCall(query, variables);
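The `src/api.ts` changes are twofold: the GraphQL query is reformatted (GraphQL treats commas between arguments as optional whitespace, so dropping them is purely cosmetic) and the unused `stepId`/`step` fields are removed, while each `catch` block gains a `console.error(e)` so the full error is surfaced before a trimmed message is rethrown. A minimal sketch of that error-handling pattern, with the hypothetical helper `safeRequest` standing in for the real API methods:

```ts
import { AxiosError } from 'axios';

// Hypothetical helper illustrating the pattern used in api.ts:
// log the complete error for debugging, then rethrow a trimmed message.
async function safeRequest<T>(call: () => Promise<{ data: T }>): Promise<T> {
  try {
    const response = await call();
    return response.data;
  } catch (e) {
    console.error(e); // full error, including stack and response payload
    if (e instanceof AxiosError) {
      throw new Error(JSON.stringify(e.response?.data));
    }
    throw e;
  }
}
```
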
15 changes: 5 additions & 10 deletions src/instrumentation/index.ts
@@ -1,20 +1,15 @@
import { LiteralClient, Step, Thread } from '..';
import { LiteralClient } from '..';
import { LiteralCallbackHandler } from './langchain';
import { instrumentLlamaIndex, withThread } from './llamaindex';
import instrumentOpenAI, {
InstrumentOpenAIOptions,
OpenAIOutput
} from './openai';
import instrumentOpenAI from './openai';
import { InstrumentOpenAIOptions } from './openai';
import { makeInstrumentVercelSDK } from './vercel-sdk';

export type { InstrumentOpenAIOptions } from './openai';

export default (client: LiteralClient) => ({
openai: (
output: OpenAIOutput,
parent?: Step | Thread,
options?: InstrumentOpenAIOptions
) => instrumentOpenAI(client, output, parent, options),
openai: (options?: InstrumentOpenAIOptions) =>
instrumentOpenAI(client, options),
langchain: {
literalCallback: (threadId?: string) => {
try {
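The net effect of this `index.ts` change shows up at the call site: instead of passing the OpenAI output and an explicit `Step`/`Thread` parent on every call, the instrumentation is installed once and resolves its parent from the active context. A usage sketch under assumed names; the `@literalai/client` package name and the `instrumentation` property on the client are not confirmed by this diff:

```ts
import OpenAI from 'openai';
import { LiteralClient } from '@literalai/client'; // assumed package name

const literalClient = new LiteralClient();
const openai = new OpenAI();

// Before: literalClient.instrumentation.openai(response, parentStep, options)
// After: patch once, globally; no output or parent argument needed.
literalClient.instrumentation.openai({ tags: ['assistant'] });

// Subsequent OpenAI calls are logged automatically.
const completion = await openai.chat.completions.create({
  model: 'gpt-4o', // hypothetical model name, for illustration only
  messages: [{ role: 'user', content: 'Hello!' }]
});
```
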
156 changes: 85 additions & 71 deletions src/instrumentation/openai.ts
@@ -18,26 +18,14 @@ import {
Thread
} from '..';

const openaiReqs: Record<
string,
{
// Record the ID of the request
id: string;
// Record the start time of the request
start: number;
// Record the inputs of the request
inputs: Record<string, any>;
// Record the stream of the request if it's a streaming request
stream?: Stream<ChatCompletionChunk | Completion>;
}
> = {};

// Define a generic type for the original function to be wrapped
type OriginalFunction<T extends any[], R> = (...args: T) => Promise<R>;

// Utility function to wrap a method
function wrapFunction<T extends any[], R>(
originalFunction: OriginalFunction<T, R>
originalFunction: OriginalFunction<T, R>,
client: LiteralClient,
options: InstrumentOpenAIOptions = {}
): OriginalFunction<T, R> {
return async function (this: any, ...args: T): Promise<R> {
const start = Date.now();
@@ -46,58 +34,57 @@ function wrapFunction<T extends any[], R>(
const result = await originalFunction.apply(this, args);

if (result instanceof Stream) {
const streamResult = result as Stream<ChatCompletionChunk | Completion>;
// If it is a streaming request, we need to process the first token to get the id
// However we also need to tee the stream so that the end developer can process the stream
const [a, b] = streamResult.tee();
// Re split the stream to store a clean instance for final processing later on
const c = a.tee()[0];
let id;
// Iterate over the stream to find the first chunk and store the id
for await (const chunk of a) {
id = chunk.id;
if (!openaiReqs[id]) {
openaiReqs[id] = {
id,
inputs: args[0],
start,
stream: c
};
break;
}
}
// @ts-expect-error Hacky way to add the id to the stream
b.id = id;
const streamResult = result;
const [returnedResult, processedResult] = streamResult.tee();

await processOpenAIOutput(client, processedResult, {
...options,
start,
inputs: args[0]
});

return b as any;
return returnedResult as R;
} else {
const regularResult = result as ChatCompletion | Completion;
const id = regularResult.id;
openaiReqs[id] = {
id,
inputs: args[0],
start
};
await processOpenAIOutput(client, result as ChatCompletion | Completion, {
...options,
start,
inputs: args[0]
});

return result;
}
};
}

// Patching the chat.completions.create function
const originalChatCompletionsCreate = OpenAI.Chat.Completions.prototype.create;
OpenAI.Chat.Completions.prototype.create = wrapFunction(
originalChatCompletionsCreate
) as any;

// Patching the completions.create function
const originalCompletionsCreate = OpenAI.Completions.prototype.create;
OpenAI.Completions.prototype.create = wrapFunction(
originalCompletionsCreate
) as any;

// Patching the completions.create function
const originalImagesGenerate = OpenAI.Images.prototype.generate;
OpenAI.Images.prototype.generate = wrapFunction(originalImagesGenerate) as any;
function instrumentOpenAI(
client: LiteralClient,
options: InstrumentOpenAIOptions = {}
) {
// Patching the chat.completions.create function
const originalChatCompletionsCreate =
OpenAI.Chat.Completions.prototype.create;
OpenAI.Chat.Completions.prototype.create = wrapFunction(
originalChatCompletionsCreate,
client,
options
) as any;

// Patching the completions.create function
const originalCompletionsCreate = OpenAI.Completions.prototype.create;
OpenAI.Completions.prototype.create = wrapFunction(
originalCompletionsCreate,
client,
options
) as any;

// Patching the images.generate function
const originalImagesGenerate = OpenAI.Images.prototype.generate;
OpenAI.Images.prototype.generate = wrapFunction(
originalImagesGenerate,
client,
options
) as any;
}

function processChatDelta(
newDelta: ChatCompletionChunk.Choice.Delta,
@@ -296,22 +283,49 @@ export interface InstrumentOpenAIOptions {
tags?: Maybe<string[]>;
}

const instrumentOpenAI = async (
export interface ProcessOpenAIOutput extends InstrumentOpenAIOptions {
start: number;
inputs: Record<string, any>;
}

function isStream(obj: any): boolean {
return (
obj !== null &&
typeof obj === 'object' &&
typeof obj.pipe === 'function' &&
typeof obj.on === 'function' &&
typeof obj.read === 'function'
);
}

const processOpenAIOutput = async (
client: LiteralClient,
output: OpenAIOutput,
parent?: Step | Thread,
options: InstrumentOpenAIOptions = {}
{ start, tags, inputs }: ProcessOpenAIOutput
) => {
//@ts-expect-error - This is a hacky way to get the id from the stream
const outputId = output.id;
const { stream, start, inputs } = openaiReqs[outputId];
const baseGeneration = {
provider: 'openai',
model: inputs.model,
settings: getSettings(inputs),
tags: options.tags
tags: tags
};

let threadFromStore: Thread | null = null;
try {
threadFromStore = client.getCurrentThread();
} catch (error) {
// Ignore error thrown if getCurrentThread is called outside of a context
}

let stepFromStore: Step | null = null;
try {
stepFromStore = client.getCurrentStep();
} catch (error) {
// Ignore error thrown if getCurrentStep is called outside of a context
}

const parent = stepFromStore || threadFromStore;

if ('data' in output) {
// Image Generation

@@ -322,14 +336,16 @@ const instrumentOpenAI = async (
output: output,
startTime: new Date(start).toISOString(),
endTime: new Date().toISOString(),
tags: options.tags
tags: tags
};

const step = parent
? parent.step(stepData)
: client.step({ ...stepData, type: 'run' });
await step.send();
} else if (output instanceof Stream) {
} else if (output instanceof Stream || isStream(output)) {
const stream = output as Stream<ChatCompletionChunk | Completion>;

if (!stream) {
throw new Error('Stream not found');
}
@@ -460,8 +476,6 @@ const instrumentOpenAI = async (
}
}
}

delete openaiReqs[outputId];
};

export default instrumentOpenAI;
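
Since `processOpenAIOutput` now calls `client.getCurrentStep()` and `client.getCurrentThread()` (preferring the step) instead of receiving an explicit parent, an instrumented call made inside a wrapped thread is attached to it automatically. A sketch, assuming the client exposes a `thread(...).wrap(...)` API for establishing that context; the thread name and model are hypothetical:

```ts
import OpenAI from 'openai';
import { LiteralClient } from '@literalai/client'; // assumed package name

const literalClient = new LiteralClient();
const openai = new OpenAI();
literalClient.instrumentation.openai(); // install the patch once

// Inside the wrapped callback, getCurrentThread() resolves to this thread,
// so the generation below is logged as its child.
await literalClient
  .thread({ name: 'Support conversation' }) // hypothetical thread name
  .wrap(() =>
    openai.chat.completions.create({
      model: 'gpt-4o',
      messages: [{ role: 'user', content: 'How do I reset my password?' }]
    })
  );

// Outside any wrapped context, getCurrentThread()/getCurrentStep() throw,
// processOpenAIOutput swallows the error, and the call is logged as a
// standalone 'run' step.
```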