feat: predefined query for extracting tool calls #7438

Merged: 12 commits, May 28, 2025
7 changes: 6 additions & 1 deletion src/phoenix/session/evaluation.py
@@ -23,13 +23,18 @@
import phoenix.trace.v1 as pb
from phoenix.config import get_env_collector_endpoint, get_env_host, get_env_port
from phoenix.session.client import Client
from phoenix.trace.dsl.helpers import get_qa_with_reference, get_retrieved_documents
from phoenix.trace.dsl.helpers import (
    get_called_tools,
    get_qa_with_reference,
    get_retrieved_documents,
)
from phoenix.trace.exporter import HttpExporter
from phoenix.trace.span_evaluations import Evaluations

__all__ = [
    "get_retrieved_documents",
    "get_qa_with_reference",
    "get_called_tools",
    "add_evaluations",
]

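Since this file now re-exports the helper, either import path works. A minimal sketch of the two equivalent imports (not part of the diff itself):

```python
# phoenix.session.evaluation re-exports the helper defined in phoenix.trace.dsl.helpers,
# so either of these imports resolves to the same function.
from phoenix.session.evaluation import get_called_tools
# from phoenix.trace.dsl.helpers import get_called_tools  # equivalent
```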
91 changes: 90 additions & 1 deletion src/phoenix/trace/dsl/helpers.py
@@ -1,6 +1,7 @@
import json
import warnings
from datetime import datetime
from typing import Optional, Protocol, Union, cast
from typing import Any, Iterable, Mapping, Optional, Protocol, Union, cast

import pandas as pd
from openinference.semconv.trace import DocumentAttributes, SpanAttributes
@@ -13,11 +14,16 @@
INPUT_VALUE = SpanAttributes.INPUT_VALUE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
LLM_FUNCTION_CALL = SpanAttributes.LLM_FUNCTION_CALL
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES


INPUT = {"input": INPUT_VALUE}
OUTPUT = {"output": OUTPUT_VALUE}
IO = {**INPUT, **OUTPUT}


IS_ROOT = "parent_id is None"
IS_LLM = "span_kind == 'LLM'"
IS_RETRIEVER = "span_kind == 'RETRIEVER'"
@@ -125,3 +131,86 @@ def get_qa_with_reference(
    df_ref = pd.DataFrame({"reference": ref})
    df_qa_ref = pd.concat([df_qa, df_ref], axis=1, join="inner").set_index("context.span_id")
    return df_qa_ref


def get_called_tools(
    obj: CanQuerySpans,
    *,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    project_name: Optional[str] = None,
    timeout: Optional[int] = DEFAULT_TIMEOUT_IN_SECONDS,
    function_name_only: bool = False,
) -> Optional[pd.DataFrame]:
    """Retrieve tool calls made by LLM spans within a specified time range.

    This function queries LLM spans and extracts tool calls from their output messages.
    It can return either just the function names or full function calls with arguments.

    Args:
        obj: An object that implements the CanQuerySpans protocol for querying spans.
        start_time: Optional start time to filter spans. If None, no start time filter is applied.
        end_time: Optional end time to filter spans. If None, no end time filter is applied.
        project_name: Optional project name to filter spans. If None, uses the environment project name.
        timeout: Optional timeout in seconds for the query. Defaults to DEFAULT_TIMEOUT_IN_SECONDS.
        function_name_only: If True, returns only function names. If False, returns full function calls
            with arguments. Defaults to False.

    Returns:
        A pandas DataFrame containing the tool calls, or None if no spans are found.
        The DataFrame includes columns for input messages, output messages, and tool calls.
    """  # noqa: E501
    project_name = project_name or get_env_project_name()

    def extract_tool_calls(outputs: list[dict[str, Any]]) -> Optional[list[str]]:
        if not isinstance(outputs, list) or not outputs:
            return None
        ans = []
        if isinstance(message := outputs[0].get("message"), Mapping) and isinstance(
            tool_calls := message.get("tool_calls"), Iterable
        ):
            for tool_call in tool_calls:
                if not isinstance(tool_call, Mapping):
                    continue
                if not isinstance(tc := tool_call.get("tool_call"), Mapping):
                    continue
                if not isinstance(function := tc.get("function"), Mapping):
                    continue
                if not isinstance(name := function.get("name"), str):
                    continue
                if function_name_only:
                    ans.append(name)
                    continue
                kwargs = {}
                if isinstance(arguments := function.get("arguments"), str):
                    try:
                        kwargs = json.loads(arguments)
                    except Exception:
                        pass
                kwargs_str = "" if not kwargs else ", ".join(f"{k}={v}" for k, v in kwargs.items())
                ans.append(f"{name}({kwargs_str})")
        return ans or None

    df_qa = cast(
        pd.DataFrame,
        obj.query_spans(
            SpanQuery()
            .where(IS_LLM)
            .select(
                input=LLM_INPUT_MESSAGES,
                output=LLM_OUTPUT_MESSAGES,
            ),
            start_time=start_time,
            end_time=end_time,
            project_name=project_name,
            timeout=timeout,
        ),
    )

    if df_qa is None:
        print("No spans found.")
        return None

    df_qa["tool_call"] = df_qa["output"].apply(extract_tool_calls)

    return df_qa
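As a quick way to exercise the new predefined query, a usage sketch follows. It assumes a running Phoenix instance that already holds traced LLM spans; `px.Client()` picks up the collector endpoint from the environment, and the project name shown is illustrative.

```python
import phoenix as px

from phoenix.trace.dsl.helpers import get_called_tools

client = px.Client()  # assumes a reachable Phoenix collector (endpoint read from the environment)

# Full calls rendered as "name(arg=value, ...)" strings, one list per LLM span.
calls_df = get_called_tools(client, project_name="default")

# Only the tool names, e.g. ["multiply", "add"].
names_df = get_called_tools(client, project_name="default", function_name_only=True)

if calls_df is not None:
    print(calls_df[["input", "output", "tool_call"]].head())
```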
207 changes: 206 additions & 1 deletion tests/unit/trace/dsl/conftest.py
@@ -111,7 +111,11 @@ async def default_project(db: DbSessionFactory) -> None:
attributes={
"input": {"value": "xyz"},
"retrieval": {
"documents": [{}, {}, {"document": {"content": "C", "score": 3}}],
"documents": [
{},
{},
{"document": {"content": "C", "score": 3}},
],
},
},
events=[],
@@ -143,6 +147,207 @@ async def default_project(db: DbSessionFactory) -> None:
)
.returning(models.Span.id)
)
await session.execute(
insert(models.Span)
.values(
trace_rowid=trace_rowid,
span_id="89101",
parent_id="2345",
name="llm span",
span_kind="LLM",
start_time=datetime.fromisoformat("2021-01-01T00:00:05.000+00:00"),
end_time=datetime.fromisoformat("2021-01-01T00:00:20.000+00:00"),
attributes={
"llm": {
"input_messages": [
{
"message": {
"role": "user",
"content": "what is 2 times 3, and what is 2 plus 3",
}
}
],
"output_messages": [
{
"message": {
"role": "assistant",
"tool_calls": [
{
"tool_call": {
"id": "a",
"function": {
"name": "multiply",
"arguments": '{\n "a": 2,\n "b": 3\n}',
},
}
},
{
"tool_call": {
"id": "b",
"function": {
"name": "add",
"arguments": '{\n "a": 2,\n "b": 3\n}',
},
}
},
],
}
}
],
},
},
events=[],
status_code="OK",
status_message="okay",
cumulative_error_count=0,
cumulative_llm_token_count_prompt=10,
cumulative_llm_token_count_completion=5,
)
.returning(models.Span.id)
)
await session.execute(
insert(models.Span)
.values(
trace_rowid=trace_rowid,
span_id="91011",
parent_id="2345",
name="llm span",
span_kind="LLM",
start_time=datetime.fromisoformat("2021-01-01T00:00:05.000+00:00"),
end_time=datetime.fromisoformat("2021-01-01T00:00:20.000+00:00"),
attributes={
"llm": {
"input_messages": [{"message": {"role": "user", "content": "call foo"}}],
"output_messages": [
{
"message": {
"role": "assistant",
"tool_calls": [
{
"tool_call": {
"id": "c",
"function": {
"name": "foo",
},
}
}
],
}
}
],
},
},
events=[],
status_code="OK",
status_message="okay",
cumulative_error_count=0,
cumulative_llm_token_count_prompt=8,
cumulative_llm_token_count_completion=4,
)
.returning(models.Span.id)
)
await session.execute(
insert(models.Span)
.values(
trace_rowid=trace_rowid,
span_id="111213",
parent_id="2345",
name="llm span",
span_kind="LLM",
start_time=datetime.fromisoformat("2021-01-01T00:00:25.000+00:00"),
end_time=datetime.fromisoformat("2021-01-01T00:00:35.000+00:00"),
attributes={
"llm": {
"input_messages": [{"message": {"role": "user", "content": "abc"}}],
"output_messages": [{"message": {"role": "assistant", "content": "xyz"}}],
}
},
events=[],
status_code="OK",
status_message="okay",
cumulative_error_count=0,
cumulative_llm_token_count_prompt=6,
cumulative_llm_token_count_completion=15,
)
.returning(models.Span.id)
)
await session.execute(
insert(models.Span)
.values(
trace_rowid=trace_rowid,
span_id="131415",
parent_id="2345",
name="llm span",
span_kind="LLM",
start_time=datetime.fromisoformat("2021-01-01T00:00:40.000+00:00"),
end_time=datetime.fromisoformat("2021-01-01T00:00:50.000+00:00"),
attributes={
"llm": {
"input_messages": [
{
"message": {
"role": "user",
"content": "test empty output",
}
}
],
"output_messages": None,
}
},
events=[],
status_code="OK",
status_message="okay",
cumulative_error_count=0,
cumulative_llm_token_count_prompt=5,
cumulative_llm_token_count_completion=5,
)
.returning(models.Span.id)
)
await session.execute(
insert(models.Span)
.values(
trace_rowid=trace_rowid,
span_id="171819",
parent_id="2345",
name="llm span",
span_kind="LLM",
start_time=datetime.fromisoformat("2021-01-01T00:01:10.000+00:00"),
end_time=datetime.fromisoformat("2021-01-01T00:01:20.000+00:00"),
attributes={
"llm": {
"input_messages": [
{
"message": {
"role": "user",
"content": "test invalid tool",
}
}
],
"output_messages": [
{
"message": {
"role": "assistant",
"tool_calls": [
{
"tool_call": {
"id": "invalid",
}
}
],
}
}
],
}
},
events=[],
status_code="OK",
status_message="okay",
cumulative_error_count=0,
cumulative_llm_token_count_prompt=5,
cumulative_llm_token_count_completion=5,
)
.returning(models.Span.id)
)


@pytest.fixture
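For reviewers, here is what the new helper should extract from the fixture spans above, hand-traced through `extract_tool_calls` (a sketch only; these values are not asserted in this file):

```python
# Expected "tool_call" values per fixture span (keyed by span_id), traced by hand:
expected_tool_calls = {
    "89101": ["multiply(a=2, b=3)", "add(a=2, b=3)"],  # JSON arguments parsed into kwargs
    "91011": ["foo()"],  # name present, no arguments -> empty parentheses
    "111213": None,  # plain assistant reply, no tool_calls
    "131415": None,  # output_messages is None
    "171819": None,  # tool_call lacks a "function" mapping
}

# With function_name_only=True, span "89101" would instead yield ["multiply", "add"]
# and "91011" would yield ["foo"].
```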