
Commit e859550

Add tracing to model provider and embedding provider

1 parent 3e6ebed commit e859550

3 files changed: +171 / -33 lines

ballerina/embedding_provider.bal
Lines changed: 41 additions & 8 deletions

@@ -15,6 +15,7 @@
 // under the License.
 
 import ballerina/ai;
+import ballerina/ai.observe;
 import ballerinax/azure.openai.embeddings;
 
 # EmbeddingProvider provides an interface for interacting with Azure OpenAI Embedding Models.
@@ -76,19 +77,38 @@ public distinct isolated client class EmbeddingProvider {
     # + chunk - The `ai:Chunk` containing the content to embed
     # + return - The resulting `ai:Embedding` on success; otherwise, returns an `ai:Error`
     isolated remote function embed(ai:Chunk chunk) returns ai:Embedding|ai:Error {
+        observe:EmbeddingSpan span = observe:createEmbeddingSpan(self.deploymentId);
+        span.addProvider("azure.ai.openai");
+
         if chunk !is ai:TextDocument|ai:TextChunk {
-            return error ai:Error("Unsupported document type. only 'ai:TextDocument' or 'ai:TextChunk' is supported");
+            ai:Error err = error ai:Error("Unsupported chunk type. only 'ai:TextDocument|ai:TextChunk' is supported");
+            span.close(err);
+            return err;
         }
+
         do {
+            span.addInputContent(chunk.content);
             embeddings:Inline_response_200 response = check self.embeddingsClient->/deployments/[self.deploymentId]/embeddings.post(
                 apiVersion = self.apiVersion,
                 payload = {
                     input: chunk.content
                 }
             );
-            return check response.data[0].embedding.cloneWithType();
+
+            span.addInputTokenCount(response.usage.prompt_tokens);
+            if response.data.length() == 0 {
+                ai:Error err = error("No embeddings generated for the provided chunk");
+                span.close(err);
+                return err;
+            }
+
+            ai:Embedding embedding = check response.data[0].embedding.cloneWithType();
+            span.close();
+            return embedding;
         } on fail error e {
-            return error ai:Error("Unable to obtain embedding for the provided chunk", e);
+            ai:Error err = error ai:Error("Unable to obtain embedding for the provided chunk", e);
+            span.close(err);
+            return err;
         }
     }
 
@@ -97,10 +117,18 @@ public distinct isolated client class EmbeddingProvider {
     # + chunks - The array of chunks to be converted into embeddings
     # + return - An array of embeddings on success, or an `ai:Error`
     isolated remote function batchEmbed(ai:Chunk[] chunks) returns ai:Embedding[]|ai:Error {
+        observe:EmbeddingSpan span = observe:createEmbeddingSpan(self.deploymentId);
+        span.addProvider("azure.ai.openai");
+
         if !chunks.every(chunk => chunk is ai:TextChunk|ai:TextDocument) {
-            return error("Unsupported chunk type. only 'ai:TextChunk[]|ai:TextDocument[]' is supported");
+            ai:Error err = error("Unsupported chunk type. only 'ai:TextChunk[]|ai:TextDocument[]' is supported");
+            span.close(err);
+            return err;
         }
         do {
+            string[] input = chunks.map(chunk => chunk.content.toString());
+            span.addInputContent(input);
+
             embeddings:InputItemsString[] inputItems = from ai:Chunk chunk in chunks
                 select check chunk.content.cloneWithType();
             embeddings:Inline_response_200 response = check self.embeddingsClient->/deployments/[self.deploymentId]/embeddings.post(
@@ -109,11 +137,16 @@ public distinct isolated client class EmbeddingProvider {
                     input: inputItems
                 }
             );
-            return
-                from embeddings:Inline_response_200_data data in response.data
-                select check data.embedding.cloneWithType();
+
+            span.addInputTokenCount(response.usage.prompt_tokens);
+            ai:Embedding[] embeddings = from embeddings:Inline_response_200_data data in response.data
+                select check data.embedding.cloneWithType();
+            span.close();
+            return embeddings;
         } on fail error e {
-            return error ai:Error("Unable to obtain embedding for the provided document", e);
+            ai:Error err = error("Unable to obtain embedding for the provided document", e);
+            span.close(err);
+            return err;
        }
    }
}
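Both `embed` and `batchEmbed` now follow the same lifecycle: create an `observe:EmbeddingSpan` on entry, record attributes as they become available, and close the span exactly once on every exit path, passing the error when one occurs. Below is a minimal sketch of that lifecycle, assuming only the `ballerina/ai.observe` surface visible in this diff; `doWork` is a hypothetical stand-in for the actual Azure OpenAI request.

```ballerina
import ballerina/ai;
import ballerina/ai.observe;

// Minimal sketch of the span lifecycle applied in `embed`/`batchEmbed` above.
isolated function tracedEmbed(string deploymentId, string content) returns string|ai:Error {
    observe:EmbeddingSpan span = observe:createEmbeddingSpan(deploymentId);
    span.addProvider("azure.ai.openai");
    do {
        span.addInputContent(content);
        string result = check doWork(content);
        // Success path: close the span without an error before returning.
        span.close();
        return result;
    } on fail error e {
        // Failure path: close the same span with the error so it is never leaked.
        ai:Error err = error ai:Error("Unable to obtain embedding for the provided chunk", e);
        span.close(err);
        return err;
    }
}

// Hypothetical stand-in for the real Azure OpenAI embeddings call.
isolated function doWork(string content) returns string|error => content;
```

Closing the span inside the `on fail` arm is what keeps it from leaking when any `check` in the `do` block fails.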

ballerina/model-provider.bal
Lines changed: 65 additions & 7 deletions

@@ -15,6 +15,7 @@
 // under the License.
 
 import ballerina/ai;
+import ballerina/ai.observe;
 import ballerina/jballerina.java;
 import ballerinax/azure.openai.chat;
 
@@ -89,6 +90,21 @@ public isolated client class OpenAiModelProvider {
     # + return - Function to be called, chat response or an error in-case of failures
     isolated remote function chat(ai:ChatMessage[]|ai:ChatUserMessage messages, ai:ChatCompletionFunctions[] tools, string? stop = ())
             returns ai:ChatAssistantMessage|ai:Error {
+        observe:ChatSpan span = observe:createChatSpan(self.deploymentId);
+        span.addProvider("azure.ai.openai");
+        span.addOutputType(observe:TEXT);
+        if stop is string {
+            span.addStopSequence(stop);
+        }
+        span.addTemperature(self.temperature);
+        json|ai:Error jsonMsg = convertMessageToJson(messages);
+        if jsonMsg is ai:Error {
+            ai:Error err = error("Error while transforming input", jsonMsg);
+            span.close(err);
+            return err;
+        }
+        span.addInputMessages(jsonMsg);
+
         chat:CreateChatCompletionRequest request = {
             stop,
             messages: check self.mapToChatCompletionRequestMessage(messages),
@@ -97,11 +113,14 @@ public isolated client class OpenAiModelProvider {
         };
         if tools.length() > 0 {
             request.functions = tools;
+            span.addTools(tools);
         }
         chat:CreateChatCompletionResponse|error response =
             self.llmClient->/deployments/[self.deploymentId]/chat/completions.post(self.apiVersion, request);
         if response is error {
-            return error ai:LlmConnectionError("Error while connecting to the model", response);
+            ai:Error err = error ai:LlmConnectionError("Error while connecting to the model", response);
+            span.close(err);
+            return err;
         }
 
         record {|
@@ -113,24 +132,52 @@ public isolated client class OpenAiModelProvider {
         |}[]? choices = response.choices;
 
         if choices is () || choices.length() == 0 {
-            return error ai:LlmInvalidResponseError("Empty response from the model when using function call API");
+            ai:Error err = error ai:LlmInvalidResponseError("Empty response from the model when using function call API");
+            span.close(err);
+            return err;
+        }
+
+        string|int? responseId = response.id;
+        if responseId is string {
+            span.addResponseId(responseId);
+        }
+        int? inputTokens = response.usage?.prompt_tokens;
+        if inputTokens is int {
+            span.addInputTokenCount(inputTokens);
+        }
+        int? outputTokens = response.usage?.completion_tokens;
+        if outputTokens is int {
+            span.addOutputTokenCount(outputTokens);
+        }
+        string? finishReason = choices[0].finish_reason;
+        if finishReason is string {
+            span.addFinishReason(finishReason);
        }
+
         chat:ChatCompletionResponseMessage? message = choices[0].message;
         ai:ChatAssistantMessage chatAssistantMessage = {role: ai:ASSISTANT, content: message?.content};
         chat:ChatCompletionFunctionCall? functionCall = message?.function_call;
-        if functionCall is chat:ChatCompletionFunctionCall {
-            chatAssistantMessage.toolCalls = [check self.mapToFunctionCall(functionCall)];
+        if functionCall is () {
+            span.addOutputMessages(chatAssistantMessage);
+            span.close();
+            return chatAssistantMessage;
+        }
+        ai:FunctionCall|ai:Error toolCall = self.mapToFunctionCall(functionCall);
+        if toolCall is ai:Error {
+            span.close(toolCall);
+            return toolCall;
        }
+        chatAssistantMessage.toolCalls = [toolCall];
         return chatAssistantMessage;
     }
 
     # Sends a chat request to the model and generates a value that belongs to the type
     # corresponding to the type descriptor argument.
-    # 
+    #
     # + prompt - The prompt to use in the chat messages
     # + td - Type descriptor specifying the expected return type format
     # + return - Generates a value that belongs to the type, or an error if generation fails
-    isolated remote function generate(ai:Prompt prompt, @display {label: "Expected type"} typedesc<anydata> td = <>) 
+    isolated remote function generate(ai:Prompt prompt, @display {label: "Expected type"} typedesc<anydata> td = <>)
             returns td|ai:Error = @java:Method {
         'class: "io.ballerina.lib.ai.azure.Generator"
     } external;
@@ -158,7 +205,7 @@ public isolated client class OpenAiModelProvider {
                 assistantMessage["content"] = message?.content;
             }
             chatCompletionRequestMessages.push(assistantMessage);
-        } else if message is ai:ChatFunctionMessage {
+        } else {
             chatCompletionRequestMessages.push(message);
         }
     }
@@ -233,3 +280,14 @@ isolated function getChatMessageStringContent(ai:Prompt|string prompt) returns s
     }
     return promptStr.trim();
 }
+
+isolated function convertMessageToJson(ai:ChatMessage[]|ai:ChatMessage messages) returns json|ai:Error {
+    if messages is ai:ChatMessage[] {
+        return messages.'map(msg => msg is ai:ChatUserMessage|ai:ChatSystemMessage ? check convertMessageToJson(msg) : msg);
+    }
+    return messages !is ai:ChatUserMessage|ai:ChatSystemMessage ? messages :
+            {role: messages.role, content: check getChatMessageStringContent(messages.content), name: messages.name};
+}
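The attribute recording in `chat` is deliberately defensive: `response.id`, the usage counters, and the finish reason are all optional in the Azure response, so each is type-tested before being attached to the span. A small sketch of that guard pattern, assuming the `observe:ChatSpan` methods used above (`recordUsage` itself is a hypothetical helper, not part of the commit):

```ballerina
import ballerina/ai.observe;

// Sketch of the optional-field guards used in `chat` above: each attribute
// is recorded only when the response actually carries a value.
isolated function recordUsage(observe:ChatSpan span, int? inputTokens,
        int? outputTokens, string? finishReason) {
    if inputTokens is int {
        span.addInputTokenCount(inputTokens);
    }
    if outputTokens is int {
        span.addOutputTokenCount(outputTokens);
    }
    if finishReason is string {
        span.addFinishReason(finishReason);
    }
}
```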

ballerina/provider_utils.bal
Lines changed: 65 additions & 18 deletions

@@ -15,6 +15,7 @@
 // under the License.
 
 import ballerina/ai;
+import ballerina/ai.observe;
 import ballerina/constraint;
 import ballerina/lang.array;
 import ballerinax/azure.openai.chat;
@@ -105,17 +106,22 @@ isolated function getGetResultsToolChoice() returns chat:ChatCompletionNamedTool
     }
 };
 
-isolated function getGetResultsTool(map<json> parameters) returns chat:ChatCompletionTool[]|error =>
-    [
-        {
-            'type: FUNCTION,
-            'function: {
-                name: GET_RESULTS_TOOL,
-                parameters: check parameters.cloneWithType(),
-                description: "Tool to call with the response from a large language model (LLM) for a user prompt."
-            }
+isolated function getGetResultsTool(map<json> parameters) returns chat:ChatCompletionTool[]|ai:Error {
+    chat:ChatCompletionFunctionParameters|error toolParam = parameters.ensureType();
+    if toolParam is error {
+        return error("Error in generated schema: " + toolParam.message());
     }
-    ];
+    return [
+        {
+            'type: FUNCTION,
+            'function: {
+                name: GET_RESULTS_TOOL,
+                parameters: toolParam,
+                description: "Tool to call with the response from a large language model (LLM) for a user prompt."
+            }
+        }
+    ];
+}
 
 isolated function generateChatCreationContent(ai:Prompt prompt) returns DocumentContentPart[]|ai:Error {
     string[] & readonly strings = prompt.strings;
@@ -234,11 +240,19 @@ isolated function handleParseResponseError(error chatResponseError) returns erro
 isolated function generateLlmResponse(chat:Client llmClient, string deploymentId,
         string apiVersion, decimal temperature, int maxTokens, ai:Prompt prompt,
         typedesc<json> expectedResponseTypedesc) returns anydata|ai:Error {
-    DocumentContentPart[] content = check generateChatCreationContent(prompt);
-    ResponseSchema ResponseSchema = check getExpectedResponseSchema(expectedResponseTypedesc);
-    chat:ChatCompletionTool[]|error tools = getGetResultsTool(ResponseSchema.schema);
-    if tools is error {
-        return error("Error in generated schema: " + tools.message());
+    observe:GenerateContentSpan span = observe:createGenerateContentSpan(deploymentId);
+    span.addTemperature(temperature);
+
+    DocumentContentPart[] content;
+    ResponseSchema responseSchema;
+    chat:ChatCompletionTool[] tools;
+    do {
+        content = check generateChatCreationContent(prompt);
+        responseSchema = check getExpectedResponseSchema(expectedResponseTypedesc);
+        tools = check getGetResultsTool(responseSchema.schema);
+    } on fail ai:Error err {
+        span.close(err);
+        return err;
     }
 
     chat:CreateChatCompletionRequest request = {
@@ -253,13 +267,43 @@ isolated function generateLlmResponse(chat:Client llmClient, string deploymentId
         max_tokens: maxTokens,
         tool_choice: getGetResultsToolChoice()
     };
+    span.addInputMessages(request.messages.toJson());
 
     chat:CreateChatCompletionResponse|error response =
         llmClient->/deployments/[deploymentId]/chat/completions.post(apiVersion, request);
     if response is error {
-        return error("LLM call failed: " + response.message(), cause = response.cause(), detail = response.detail());
+        ai:Error err = error("LLM call failed: " + response.message(), cause = response.cause(), detail = response.detail());
+        span.close(err);
+        return err;
    }
 
+    string? responseId = response.id;
+    if responseId is string {
+        span.addResponseId(responseId);
+    }
+    int? inputTokens = response.usage?.prompt_tokens;
+    if inputTokens is int {
+        span.addInputTokenCount(inputTokens);
+    }
+    int? outputTokens = response.usage?.completion_tokens;
+    if outputTokens is int {
+        span.addOutputTokenCount(outputTokens);
+    }
+
+    anydata|ai:Error result = ensureAnydataResult(response, expectedResponseTypedesc,
+            responseSchema.isOriginallyJsonObject, span);
+    if result is ai:Error {
+        span.close(result);
+        return result;
+    }
+    span.addOutputMessages(result.toJson());
+    span.close();
+    return result;
+}
+
+isolated function ensureAnydataResult(chat:CreateChatCompletionResponse response,
+        typedesc<json> expectedResponseTypedesc, boolean isOriginallyJsonObject,
+        observe:GenerateContentSpan span) returns anydata|ai:Error {
     record {
         chat:ChatCompletionResponseMessage message?;
         chat:ContentFilterChoiceResults content_filter_results?;
@@ -276,15 +320,18 @@ isolated function generateLlmResponse(chat:Client llmClient, string deploymentId
     if toolCalls is () || toolCalls.length() == 0 {
         return error(NO_RELEVANT_RESPONSE_FROM_THE_LLM);
     }
+    string? finishReason = choices[0].finish_reason;
+    if finishReason is string {
+        span.addFinishReason(finishReason);
+    }
 
     chat:ChatCompletionMessageToolCall tool = toolCalls[0];
     map<json>|error arguments = tool.'function.arguments.fromJsonStringWithType();
     if arguments is error {
         return error(NO_RELEVANT_RESPONSE_FROM_THE_LLM);
     }
 
-    anydata|error res = parseResponseAsType(arguments.toJsonString(), expectedResponseTypedesc,
-            ResponseSchema.isOriginallyJsonObject);
+    anydata|error res = parseResponseAsType(arguments.toJsonString(), expectedResponseTypedesc, isOriginallyJsonObject);
     if res is error {
         return error ai:LlmInvalidGenerationError(string `Invalid value returned from the LLM Client, expected: '${
                 expectedResponseTypedesc.toBalString()}', found '${res.toBalString()}'`);