feat: implement cache on getPrompt and getPromptById methods #86

Merged
merged 11 commits into from
Nov 29, 2024
175 changes: 125 additions & 50 deletions src/api.ts
@@ -325,6 +325,63 @@ type CreateAttachmentParams = {
metadata?: Maybe<Record<string, any>>;
};

export class SharedCache {
private static instance: SharedCache | null = null;
Contributor

Let's simplify even more:

private cache: Map<string, any>;

private constructor() {
this.cache = new Map();
}

static getInstance(): SharedCache {
if (!SharedCache.instance) {
SharedCache.instance = new SharedCache();
}
return SharedCache.instance;
}

public getPromptCacheKey(
Contributor

Now that the cache is no longer prompt-only, let's not leak prompt business logic into the cache layer.

id?: string,
name?: string,
version?: number
Contributor

When a function takes only optional params in JS, it's recommended to accept a single object, because positional calls make for poor UX:

instance.getPromptCacheKey(undefined, undefined, version)
// vs
instance.getPromptCacheKey({ version })
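
The object-parameter form suggested above could look roughly like the sketch below. It is only an illustration: the `PromptKeyParams` type name and the standalone function are assumptions made for this example, not code from the PR.

```typescript
// Hypothetical parameter type; the field names mirror the positional
// parameters shown in the diff.
type PromptKeyParams = {
  id?: string;
  name?: string;
  version?: number;
};

// Same key-building rules as in the diff, but taking one options object.
function getPromptCacheKey({ id, name, version }: PromptKeyParams): string {
  if (id) {
    return id;
  }
  if (name && typeof version === 'number') {
    return `${name}:${version}`;
  }
  if (name) {
    return name;
  }
  throw new Error('Either id or name must be provided');
}

// Callers name only the fields they actually have.
console.log(getPromptCacheKey({ name: 'Default', version: 0 })); // "Default:0"
console.log(getPromptCacheKey({ id: 'test-id' })); // "test-id"
```

With this shape, adding another optional field later does not change existing call sites.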

Contributor Author

Now I understand why auto-complete suggested an object, thanks.

): string {
if (id) {
return id;
} else if (name && (version || version === 0)) {
Contributor

There is actually a way to do that explicitly:

Suggested change
} else if (name && (version || version === 0)) {
} else if (name && typeof version === "number") {
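
To see where the two checks agree and where they differ, here is a small standalone comparison. The helper names `truthyOrZero` and `isNumber` are made up for this sketch, and the inputs are typed as `unknown` on the assumption that untyped data could reach the method at runtime.

```typescript
// The check as written in the diff.
const truthyOrZero = (version: unknown): boolean =>
  Boolean(version || version === 0);

// The explicit check from the suggested change.
const isNumber = (version: unknown): boolean => typeof version === 'number';

// Print both results side by side for a few representative inputs.
for (const value of [0, 1, undefined, '2', NaN]) {
  console.log(String(value), truthyOrZero(value), isNumber(value));
}
```

Both checks accept `0` and `1` and reject `undefined`. They differ on the string `'2'` (accepted only by the truthiness check) and on `NaN` (accepted only by the `typeof` check), so the `typeof` form states the intent more directly but does not by itself filter out `NaN`.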

return `${name}:${version}`;
} else if (name) {
return name;
}
throw new Error('Either id or name must be provided');
}

public getPrompt(key: string): Prompt {
return this.get(key);
}

public putPrompt(prompt: Prompt): void {
Contributor

Same here

this.put(prompt.id, prompt);
this.put(prompt.name, prompt);
this.put(`${prompt.name}:${prompt.version}`, prompt);
}

public getCache(): Map<string, any> {
return this.cache;
}

public get(key: string): any {
return this.cache.get(key);
}

public put(key: string, value: any): void {
this.cache.set(key, value);
}

public clear(): void {
this.cache.clear();
}
}
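
The review comments on `getPromptCacheKey` and `putPrompt` ask that the cache layer stay free of prompt-specific logic. One way that separation might look is sketched below; `PromptLike` and `cachePrompt` are hypothetical names introduced for this example, and the sketch is not the code that was eventually merged.

```typescript
// Generic key/value cache: it knows nothing about prompts.
class SharedCache {
  private static instance: SharedCache | null = null;
  private cache = new Map<string, unknown>();

  private constructor() {}

  static getInstance(): SharedCache {
    if (!SharedCache.instance) {
      SharedCache.instance = new SharedCache();
    }
    return SharedCache.instance;
  }

  get<T>(key: string): T | undefined {
    return this.cache.get(key) as T | undefined;
  }

  put(key: string, value: unknown): void {
    this.cache.set(key, value);
  }

  clear(): void {
    this.cache.clear();
  }
}

// Hypothetical stand-in for the SDK's Prompt class.
type PromptLike = { id: string; name: string; version: number };

// Prompt-specific helper that lives outside the cache (hypothetical name).
// It stores the prompt under the same three keys as putPrompt in the diff.
function cachePrompt(cache: SharedCache, prompt: PromptLike): void {
  cache.put(prompt.id, prompt);
  cache.put(prompt.name, prompt);
  cache.put(`${prompt.name}:${prompt.version}`, prompt);
}

const cache = SharedCache.getInstance();
cachePrompt(cache, { id: 'test-id', name: 'test-prompt', version: 1 });
console.log(cache.get<PromptLike>('test-prompt:1')?.id); // "test-id"
```

Keeping the key scheme in the caller means the same cache can later hold other kinds of values without changes.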

/**
* Represents the API client for interacting with the Literal service.
* This class handles API requests, authentication, and provides methods
@@ -340,6 +397,8 @@ type CreateAttachmentParams = {
* Then you can use the `api` object to make calls to the Literal service.
*/
export class API {
/** @ignore */
public cache: SharedCache;
Contributor

This should be internal

Contributor Author

I used it to simplify my test implementation and forgot to change it back later.
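
Because `SharedCache` is a singleton, a test could in principle seed it through `SharedCache.getInstance()` without the client exposing its `cache` field. The sketch below illustrates that idea with a trimmed-down cache and a hypothetical `ApiSketch` class; it assumes the cache class is importable from the test file and is not the approach confirmed by the PR.

```typescript
// Trimmed-down singleton cache for the sketch.
class SharedCache {
  private static instance: SharedCache | null = null;
  private cache = new Map<string, unknown>();

  private constructor() {}

  static getInstance(): SharedCache {
    if (!SharedCache.instance) {
      SharedCache.instance = new SharedCache();
    }
    return SharedCache.instance;
  }

  get(key: string): unknown {
    return this.cache.get(key);
  }

  put(key: string, value: unknown): void {
    this.cache.set(key, value);
  }
}

// Hypothetical client that keeps its cache reference private.
class ApiSketch {
  private cache = SharedCache.getInstance();

  lookup(key: string): unknown {
    return this.cache.get(key);
  }
}

// A test seeds the shared instance directly instead of going through the client.
SharedCache.getInstance().put('test-id', { id: 'test-id' });
const api = new ApiSketch();
console.log(api.lookup('test-id')); // the seeded object
```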

/** @ignore */
public client: LiteralClient;
/** @ignore */
@@ -372,6 +431,8 @@ export class API {
throw new Error('LITERAL_API_URL not set');
}

this.cache = SharedCache.getInstance();

this.apiKey = apiKey;
this.url = url;
this.environment = environment;
@@ -399,7 +460,7 @@ export class API {
* @returns The data part of the response from the GraphQL endpoint.
* @throws Will throw an error if the GraphQL call returns errors or if the request fails.
*/
private async makeGqlCall(query: string, variables: any) {
private async makeGqlCall(query: string, variables: any, timeout?: number) {
try {
const response = await axios({
url: this.graphqlEndpoint,
@@ -408,7 +469,8 @@ export class API {
data: {
query: query,
variables: variables
}
},
timeout
});
if (response.data.errors) {
throw new Error(JSON.stringify(response.data.errors));
@@ -2110,41 +2172,76 @@ export class API {
}

/**
* Retrieves a prompt by its id.
*
* @param id ID of the prompt to retrieve.
* @returns The prompt with given ID.
* Retrieves a prompt by its id. If the request fails, it will try to get the prompt from the cache.
*/
public async getPromptById(id: string) {
const query = `
query GetPrompt($id: String!) {
promptVersion(id: $id) {
createdAt
id
label
settings
status
tags
templateMessages
tools
type
updatedAt
url
variables
variablesDefaultValues
version
lineage {
name
query GetPrompt($id: String!) {
promptVersion(id: $id) {
createdAt
id
label
settings
status
tags
templateMessages
tools
type
updatedAt
url
variables
variablesDefaultValues
version
lineage {
name
}
}
}
}
`;

return await this.getPromptWithQuery(query, { id });
return this.getPromptWithQuery(query, { id });
}

/**
* Private helper method to execute prompt queries with error handling and caching
*/
private async getPromptWithQuery(
query: string,
variables: Record<string, any>
) {
const { id, name, version } = variables;
const cachedPrompt = this.cache.getPrompt(
this.cache.getPromptCacheKey(id, name, version)
);
const timeout = cachedPrompt ? 1000 : undefined;

try {
const result = await this.makeGqlCall(query, variables, timeout);

if (!result.data || !result.data.promptVersion) {
return cachedPrompt;
}

const promptData = result.data.promptVersion;
promptData.provider = promptData.settings?.provider;
promptData.name = promptData.lineage?.name;
delete promptData.lineage;
if (promptData.settings) {
delete promptData.settings.provider;
}

const prompt = new Prompt(this, promptData);
this.cache.putPrompt(prompt);
return prompt;
} catch (error) {
console.log('key: ', this.cache.getPromptCacheKey(id, name, version));
console.log('cachedPrompt: ', cachedPrompt);
return cachedPrompt;
}
}
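
The new `getPromptWithQuery` follows a "fetch, then fall back to cache" pattern: apply a short timeout only when a cached value exists, refresh the cache on success, and return the cached value when the call fails or returns nothing. Below is a minimal standalone sketch of that pattern. The `Fetcher` type and `getWithFallback` function are hypothetical stand-ins for `makeGqlCall` and the SDK method, and the sketch checks `!== undefined` where the diff uses a truthiness test.

```typescript
// Hypothetical fetcher standing in for makeGqlCall; resolves to null when
// the backend has no matching record.
type Fetcher<T> = (timeoutMs?: number) => Promise<T | null>;

const cache = new Map<string, unknown>();

async function getWithFallback<T>(
  key: string,
  fetcher: Fetcher<T>
): Promise<T | undefined> {
  const cached = cache.get(key) as T | undefined;
  // Only cap the wait when there is something to fall back on.
  const timeoutMs = cached !== undefined ? 1000 : undefined;

  try {
    const fresh = await fetcher(timeoutMs);
    if (fresh === null) {
      return cached;
    }
    cache.set(key, fresh); // keep the cache up to date on success
    return fresh;
  } catch {
    // Request failed or timed out: serve whatever was cached (may be undefined).
    return cached;
  }
}

async function demo(): Promise<void> {
  // First call succeeds and populates the cache.
  const first = await getWithFallback('test-id', async () => ({ version: 1 }));
  // Second call fails, so the cached value is returned instead.
  const second = await getWithFallback<{ version: number }>('test-id', async () => {
    throw new Error('DB Error');
  });
  console.log(first, second);
}

demo();
```

A consequence of this design, also visible in the diff, is that a failed request with an empty cache resolves to `undefined` rather than throwing, so callers cannot distinguish "not found" from "backend unreachable".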

/**
* Retrieves a prompt by its name and optionally by its version.
* Retrieves a prompt by its name and optionally by its version. If the request fails, it will try to get the prompt from the cache.
*
* @param name - The name of the prompt to retrieve.
* @param version - The version number of the prompt (optional).
@@ -2171,29 +2268,7 @@ export class API {
}
}
`;

return await this.getPromptWithQuery(query, { name, version });
}

private async getPromptWithQuery(
query: string,
variables: Record<string, any>
) {
const result = await this.makeGqlCall(query, variables);

if (!result.data || !result.data.promptVersion) {
return null;
}

const promptData = result.data.promptVersion;
promptData.provider = promptData.settings?.provider;
promptData.name = promptData.lineage?.name;
delete promptData.lineage;
if (promptData.settings) {
delete promptData.settings.provider;
}

return new Prompt(this, promptData);
return this.getPromptWithQuery(query, { name, version });
}

/**
70 changes: 70 additions & 0 deletions tests/api.test.ts
@@ -1,9 +1,11 @@
import axios from 'axios';
import 'dotenv/config';
import { v4 as uuidv4 } from 'uuid';

import { ChatGeneration, IGenerationMessage, LiteralClient } from '../src';
import { Dataset } from '../src/evaluation/dataset';
import { Score } from '../src/evaluation/score';
import { Prompt, PromptConstructor } from '../src/prompt-engineering/prompt';
import { sleep } from './utils';

describe('End to end tests for the SDK', function () {
@@ -597,6 +599,30 @@ describe('End to end tests for the SDK', function () {
});

describe('Prompt api', () => {
const mockPromptData: PromptConstructor = {
id: 'test-id',
name: 'test-prompt',
version: 1,
createdAt: new Date().toISOString(),
type: 'CHAT',
templateMessages: [],
tools: [],
settings: {
provider: 'test',
model: 'test',
frequency_penalty: 0,
presence_penalty: 0,
temperature: 0,
top_p: 0,
max_tokens: 0
},
variables: [],
variablesDefaultValues: {},
metadata: {},
items: [],
provider: 'test'
};

it('should get a prompt by name', async () => {
const prompt = await client.api.getPrompt('Default');

@@ -657,6 +683,50 @@ is a templated list.`;
expect(formatted[0].content).toBe(expected);
});

it('should fallback to cache when getPromptById DB call fails', async () => {
const prompt = new Prompt(client.api, mockPromptData);
client.api.cache.putPrompt(prompt);

jest
.spyOn(client.api as any, 'makeGqlCall')
.mockRejectedValueOnce(new Error('DB Error'));

const result = await client.api.getPromptById(prompt.id);
expect(result).toEqual(prompt);
});

it('should fallback to cache when getPrompt DB call fails', async () => {
const prompt = new Prompt(client.api, mockPromptData);
client.api.cache.putPrompt(prompt);
jest.spyOn(axios, 'post').mockRejectedValueOnce(new Error('DB Error'));

const result = await client.api.getPrompt(prompt.id);
expect(result).toEqual(prompt);
});

it('should update cache with fresh data on successful DB call', async () => {
const prompt = new Prompt(client.api, mockPromptData);

jest.spyOn(client.api as any, 'makeGqlCall').mockResolvedValueOnce({
data: { promptVersion: prompt }
});

await client.api.getPromptById(prompt.id);

const cachedPrompt = await client.api.cache.get(prompt.id);
expect(cachedPrompt).toBeDefined();
expect(cachedPrompt?.id).toBe(prompt.id);
});

it('should return undefined when both DB and cache fail', async () => {
jest
.spyOn(client.api as any, 'makeGqlCall')
.mockRejectedValueOnce(new Error('DB Error'));

const result = await client.api.getPromptById('non-existent-id');
expect(result).toBeUndefined();
});

it('should get a prompt A/B testing configuration', async () => {
const promptName = 'TypeScript SDK E2E Tests';
