diff --git a/literalai/instrumentation/llamaindex/__init__.py b/literalai/instrumentation/llamaindex/__init__.py new file mode 100644 index 0000000..dd0a272 --- /dev/null +++ b/literalai/instrumentation/llamaindex/__init__.py @@ -0,0 +1,24 @@ +from literalai.client import LiteralClient +from llama_index.core.instrumentation import get_dispatcher + +from literalai.instrumentation.llamaindex.event_handler import LiteralEventHandler +from literalai.instrumentation.llamaindex.span_handler import LiteralSpanHandler + +is_llamaindex_instrumented = False + +def instrument_llamaindex(client: "LiteralClient"): + global is_llamaindex_instrumented + if is_llamaindex_instrumented: + return + + root_dispatcher = get_dispatcher() + + span_handler = LiteralSpanHandler() + root_dispatcher.add_span_handler(span_handler) + + event_handler = LiteralEventHandler( + literal_client=client, llama_index_span_handler=span_handler + ) + root_dispatcher.add_event_handler(event_handler) + + is_llamaindex_instrumented = True diff --git a/literalai/instrumentation/llamaindex.py b/literalai/instrumentation/llamaindex/event_handler.py similarity index 63% rename from literalai/instrumentation/llamaindex.py rename to literalai/instrumentation/llamaindex/event_handler.py index 0b9fbb0..96813b0 100644 --- a/literalai/instrumentation/llamaindex.py +++ b/literalai/instrumentation/llamaindex/event_handler.py @@ -1,46 +1,57 @@ -import logging import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast -from typing_extensions import TypedDict -from llama_index.core.base.llms.types import MessageRole -from llama_index.core.base.response.schema import Response, StreamingResponse -from llama_index.core.instrumentation import get_dispatcher +import logging +from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast + from llama_index.core.instrumentation.event_handlers import BaseEventHandler from llama_index.core.instrumentation.events import BaseEvent +from pydantic import PrivateAttr + +from literalai.instrumentation.llamaindex.span_handler import LiteralSpanHandler +from literalai.context import active_thread_var + +from llama_index.core.instrumentation.events.agent import ( + AgentChatWithStepStartEvent, + AgentChatWithStepEndEvent, + AgentRunStepStartEvent, + AgentRunStepEndEvent, +) from llama_index.core.instrumentation.events.embedding import ( - EmbeddingEndEvent, EmbeddingStartEvent, + EmbeddingEndEvent, ) -from llama_index.core.instrumentation.events.llm import ( - LLMChatEndEvent, - LLMChatStartEvent, -) + from llama_index.core.instrumentation.events.query import QueryEndEvent, QueryStartEvent from llama_index.core.instrumentation.events.retrieval import ( RetrievalEndEvent, RetrievalStartEvent, ) + +from llama_index.core.base.llms.types import MessageRole, ChatMessage +from llama_index.core.base.response.schema import Response, StreamingResponse + +from llama_index.core.instrumentation.events.llm import ( + LLMChatEndEvent, + LLMChatStartEvent, +) + from llama_index.core.instrumentation.events.synthesis import ( - GetResponseStartEvent, SynthesizeEndEvent, ) -from llama_index.core.instrumentation.span import SimpleSpan -from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler -from llama_index.core.query_engine import RetrieverQueryEngine + from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode from openai.types import CompletionUsage -from openai.types.chat import ChatCompletion -from pydantic import PrivateAttr +from openai.types.chat import ChatCompletion, ChatCompletionChunk -from literalai.context import active_thread_var -from literalai.observability.generation import ChatGeneration, GenerationMessageRole +from literalai.observability.generation import ( + ChatGeneration, + GenerationMessage, + GenerationMessageRole, +) from literalai.observability.step import Step, StepType if TYPE_CHECKING: from literalai.client import LiteralClient -literalai_uuid_namespace = uuid.UUID("05f6b2b5-a912-47bd-958f-98a9c4496322") - def convert_message_role(role: MessageRole) -> GenerationMessageRole: mapping = { @@ -88,6 +99,27 @@ def extract_document_info(nodes: List[NodeWithScore]): ] +def build_message_dict(message: ChatMessage): + message_dict = { + "role": convert_message_role(message.role), + "content": message.content, + } + + kwargs = message.additional_kwargs + + if kwargs: + if kwargs.get("tool_call_id", None): + message_dict["tool_call_id"] = kwargs.get("tool_call_id") + if kwargs.get("name", None): + message_dict["name"] = kwargs.get("name") + tool_calls = kwargs.get("tool_calls", []) + if len(tool_calls) > 0: + message_dict["tool_calls"] = [ + tool_call.to_dict() for tool_call in tool_calls + ] + return message_dict + + def create_generation(event: LLMChatStartEvent): model_dict = event.model_dict @@ -101,13 +133,14 @@ def create_generation(event: LLMChatStartEvent): "logprobs": model_dict.get("logprobs"), "top_logprobs": model_dict.get("top_logprobs"), }, - messages=[ - {"role": convert_message_role(message.role), "content": message.content} - for message in event.messages - ], + messages=[build_message_dict(message) for message in event.messages], ) +def extract_query(x: Union[str, QueryBundle]): + return x.query_str if isinstance(x, QueryBundle) else x + + class LiteralEventHandler(BaseEventHandler): """This class handles events coming from LlamaIndex.""" @@ -115,6 +148,8 @@ class LiteralEventHandler(BaseEventHandler): _span_handler: "LiteralSpanHandler" = PrivateAttr(...) runs: Dict[str, List[Step]] = {} streaming_run_ids: List[str] = [] + _standalone_step_id: Optional[str] = None + open_runs: List[Step] = [] class Config: arbitrary_types_allowed = True @@ -128,10 +163,23 @@ def __init__( object.__setattr__(self, "_client", literal_client) object.__setattr__(self, "_span_handler", llama_index_span_handler) - @classmethod - def class_name(cls) -> str: - """Class name.""" - return "LiteralEventHandler" + def _convert_message( + self, + message: ChatMessage, + ): + tool_calls = message.additional_kwargs.get("tool_calls") + msg = GenerationMessage( + name=getattr(message, "name", None), + role=convert_message_role(message.role), + content="", + ) + + msg["content"] = message.content + + if tool_calls: + msg["tool_calls"] = [tool_call.to_dict() for tool_call in tool_calls] + + return msg def handle(self, event: BaseEvent, **kwargs) -> None: """Logic for handling event.""" @@ -139,10 +187,46 @@ def handle(self, event: BaseEvent, **kwargs) -> None: thread_id = self._span_handler.get_thread_id(event.span_id) run_id = self._span_handler.get_run_id(event.span_id) - """The events are presented here roughly in chronological order""" + # AgentChatWithStep wraps several AgentRunStep events + # as the agent may want to perform multiple tool calls in a row. + if isinstance(event, AgentChatWithStepStartEvent) or isinstance( + event, AgentRunStepStartEvent + ): + run_name = ( + "Agent Chat" + if isinstance(event, AgentChatWithStepStartEvent) + else "Agent Step" + ) + parent_run_id = None + if len(self.open_runs) > 0: + parent_run_id = self.open_runs[-1].id + + agent_run_id = str(uuid.uuid4()) + + run = self._client.start_step( + name=run_name, + type="run", + id=agent_run_id, + parent_id=parent_run_id, + ) + + self.open_runs.append(run) + + if isinstance(event, AgentChatWithStepEndEvent) or isinstance( + event, AgentRunStepEndEvent + ): + try: + step = self.open_runs.pop() + except IndexError: + logging.error( + "[Literal] Error in Llamaindex instrumentation: AgentRunStepEndEvent called without an open run." + ) + if step: + step.end() + if isinstance(event, QueryStartEvent): active_thread = active_thread_var.get() - query = extract_query_from_bundle(event.query) + query = extract_query(event.query) if not active_thread or not active_thread.name: self._client.api.upsert_thread(id=thread_id, name=query) @@ -154,7 +238,8 @@ def handle(self, event: BaseEvent, **kwargs) -> None: thread_id=thread_id, content=query, ) - + + # Retrieval wraps the Embedding step in LlamaIndex if isinstance(event, RetrievalStartEvent): run = self._client.start_step( name="RAG", @@ -206,13 +291,14 @@ def handle(self, event: BaseEvent, **kwargs) -> None: if run_id and retrieval_step: retrieved_documents = extract_document_info(event.nodes) - query = extract_query_from_bundle(event.str_or_query_bundle) + query = extract_query(event.str_or_query_bundle) retrieval_step.input = {"query": query} retrieval_step.output = {"retrieved_documents": retrieved_documents} retrieval_step.end() - if isinstance(event, GetResponseStartEvent): + # Only event where we create LLM steps + if isinstance(event, LLMChatStartEvent): if run_id: self._client.step() llm_step = self._client.start_step( @@ -221,23 +307,40 @@ def handle(self, event: BaseEvent, **kwargs) -> None: thread_id=thread_id, ) self.store_step(run_id=run_id, step=llm_step) - - if isinstance(event, LLMChatStartEvent): llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm") - if run_id and llm_step: + if not run_id and not llm_step: + self._standalone_step_id = str(uuid.uuid4()) + llm_step = self._client.start_step( + name=event.model_dict.get("model", "LLM"), + type="llm", + id=self._standalone_step_id, + # Remove thread_id for standalone runs + ) + self.store_step(run_id=self._standalone_step_id, step=llm_step) + + if llm_step: generation = create_generation(event=event) llm_step.generation = generation llm_step.name = event.model_dict.get("model") + # Actual creation of the event happens upon ending the event if isinstance(event, LLMChatEndEvent): llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm") + if not llm_step and self._standalone_step_id: + llm_step = self.get_first_step_of_type( + run_id=self._standalone_step_id, step_type="llm" + ) + response = event.response - if run_id and llm_step and response: + if llm_step and response: chat_completion = response.raw - if isinstance(chat_completion, ChatCompletion): + # ChatCompletionChunk needed for chat stream methods + if isinstance(chat_completion, ChatCompletion) or isinstance( + chat_completion, ChatCompletionChunk + ): usage = chat_completion.usage if isinstance(usage, CompletionUsage): @@ -246,6 +349,14 @@ def handle(self, event: BaseEvent, **kwargs) -> None: usage.completion_tokens ) + if self._standalone_step_id: + llm_step.generation.message_completion = ( + self._convert_message(response.message) + ) + + llm_step.end() + self._standalone_step_id = None + if isinstance(event, SynthesizeEndEvent): llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm") run = self.get_first_step_of_type(run_id=run_id, step_type="run") @@ -305,166 +416,7 @@ def get_first_step_of_type( return None - -class SpanEntry(TypedDict): - id: str - parent_id: Optional[str] - root_id: Optional[str] - is_run_root: bool - - -class LiteralSpanHandler(BaseSpanHandler[SimpleSpan]): - """This class handles spans coming from LlamaIndex.""" - - spans: Dict[str, SpanEntry] = {} - @classmethod def class_name(cls) -> str: """Class name.""" - return "LiteralSpanHandler" - - def new_span( - self, - id_: str, - bound_args: Any, - instance: Optional[Any] = None, - parent_span_id: Optional[str] = None, - tags: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ): - self.spans[id_] = { - "id": id_, - "parent_id": parent_span_id, - "root_id": None, - "is_run_root": self.is_run_root(instance, parent_span_id), - } - - if parent_span_id is not None: - self.spans[id_]["root_id"] = self.get_root_span_id(parent_span_id) - else: - self.spans[id_]["root_id"] = id_ - - def is_run_root( - self, instance: Optional[Any], parent_span_id: Optional[str] - ) -> bool: - """Returns True if the span is of type RetrieverQueryEngine, and it has no run root in its parent chain""" - if not isinstance(instance, RetrieverQueryEngine): - return False - - # Span is of correct type, we check that it doesn't have a run root in its parent chain - while parent_span_id: - parent_span = self.spans.get(parent_span_id) - - if not parent_span: - parent_span_id = None - continue - - if parent_span["is_run_root"]: - return False - - parent_span_id = parent_span["parent_id"] - - return True - - def get_root_span_id(self, span_id: Optional[str]): - """Finds the root span and returns its ID""" - if not span_id: - return None - - current_span = self.spans.get(span_id) - - if current_span is None: - return None - - while current_span["parent_id"] is not None: - current_span = self.spans.get(current_span["parent_id"]) - if current_span is None: - return None - - return current_span["id"] - - def get_run_id(self, span_id: Optional[str]): - """Go up the span chain to find a run_root, return its ID (or None)""" - if not span_id: - return None - - current_span = self.spans.get(span_id) - - if current_span is None: - return None - - while current_span: - if current_span["is_run_root"]: - return str(uuid.uuid5(literalai_uuid_namespace, current_span["id"])) - - parent_id = current_span["parent_id"] - - if parent_id: - current_span = self.spans.get(parent_id) - else: - current_span = None - - return None - - def get_thread_id(self, span_id: Optional[str]): - """Returns the root span ID as a thread ID""" - active_thread = active_thread_var.get() - - if active_thread: - return active_thread.id - - if span_id is None: - return None - - current_span = self.spans.get(span_id) - - if current_span is None: - return None - - root_id = current_span["root_id"] - - if not root_id: - return None - - root_span = self.spans.get(root_id) - - if root_span is None: - # span is already the root, uuid its own id - return str(uuid.uuid5(literalai_uuid_namespace, span_id)) - else: - # uuid the id of the root span - return str(uuid.uuid5(literalai_uuid_namespace, root_span["id"])) - - def prepare_to_exit_span( - self, - id_: str, - bound_args: Any, - instance: Optional[Any] = None, - result: Optional[Any] = None, - **kwargs: Any, - ): - """Logic for preparing to exit a span.""" - if id in self.spans: - del self.spans[id_] - - def prepare_to_drop_span( - self, - id_: str, - bound_args: Any, - instance: Optional[Any] = None, - err: Optional[BaseException] = None, - **kwargs: Any, - ): - """Logic for preparing to drop a span.""" - if id in self.spans: - del self.spans[id_] - - -def instrument_llamaindex(client: "LiteralClient"): - root_dispatcher = get_dispatcher() - span_handler = LiteralSpanHandler() - event_handler = LiteralEventHandler( - literal_client=client, llama_index_span_handler=span_handler - ) - root_dispatcher.add_event_handler(event_handler) - root_dispatcher.add_span_handler(span_handler) + return "LiteralEventHandler" diff --git a/literalai/instrumentation/llamaindex/span_handler.py b/literalai/instrumentation/llamaindex/span_handler.py new file mode 100644 index 0000000..7c5c9f1 --- /dev/null +++ b/literalai/instrumentation/llamaindex/span_handler.py @@ -0,0 +1,166 @@ +from typing_extensions import TypedDict +from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler +from llama_index.core.instrumentation.span import SimpleSpan +from typing import Any, Dict, Optional +from llama_index.core.query_engine import RetrieverQueryEngine +import uuid +from literalai.context import active_thread_var + +literalai_uuid_namespace = uuid.UUID("05f6b2b5-a912-47bd-958f-98a9c4496322") + + +class SpanEntry(TypedDict): + id: str + parent_id: Optional[str] + root_id: Optional[str] + is_run_root: bool + + +class LiteralSpanHandler(BaseSpanHandler[SimpleSpan]): + """This class handles spans coming from LlamaIndex.""" + + spans: Dict[str, SpanEntry] = {} + + def __init__(self): + super().__init__() + + def new_span( + self, + id_: str, + bound_args: Any, + instance: Optional[Any] = None, + parent_span_id: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ): + self.spans[id_] = { + "id": id_, + "parent_id": parent_span_id, + "root_id": None, + "is_run_root": self.is_run_root(instance, parent_span_id), + } + + if parent_span_id is not None: + self.spans[id_]["root_id"] = self.get_root_span_id(parent_span_id) + else: + self.spans[id_]["root_id"] = id_ + + def prepare_to_exit_span( + self, + id_: str, + bound_args: Any, + instance: Optional[Any] = None, + result: Optional[Any] = None, + **kwargs: Any, + ): + """Logic for preparing to exit a span.""" + if id_ in self.spans: + del self.spans[id_] + + def prepare_to_drop_span( + self, + id_: str, + bound_args: Any, + instance: Optional[Any] = None, + err: Optional[BaseException] = None, + **kwargs: Any, + ): + """Logic for preparing to drop a span.""" + if id_ in self.spans: + del self.spans[id_] + + def is_run_root( + self, instance: Optional[Any], parent_span_id: Optional[str] + ) -> bool: + """Returns True if the span is of type RetrieverQueryEngine, and it has no run root in its parent chain""" + if not isinstance(instance, RetrieverQueryEngine): + return False + + # Span is of correct type, we check that it doesn't have a run root in its parent chain + while parent_span_id: + parent_span = self.spans.get(parent_span_id) + + if not parent_span: + parent_span_id = None + continue + + if parent_span["is_run_root"]: + return False + + parent_span_id = parent_span["parent_id"] + + return True + + def get_root_span_id(self, span_id: Optional[str]): + """Finds the root span and returns its ID""" + if not span_id: + return None + + current_span = self.spans.get(span_id) + + if current_span is None: + return None + + while current_span["parent_id"] is not None: + current_span = self.spans.get(current_span["parent_id"]) + if current_span is None: + return None + + return current_span["id"] + + def get_run_id(self, span_id: Optional[str]): + """Go up the span chain to find a run_root, return its ID (or None)""" + if not span_id: + return None + + current_span = self.spans.get(span_id) + + if current_span is None: + return None + + while current_span: + if current_span["is_run_root"]: + return str(uuid.uuid5(literalai_uuid_namespace, current_span["id"])) + + parent_id = current_span["parent_id"] + + if parent_id: + current_span = self.spans.get(parent_id) + else: + current_span = None + + return None + + def get_thread_id(self, span_id: Optional[str]): + """Returns the root span ID as a thread ID""" + active_thread = active_thread_var.get() + + if active_thread: + return active_thread.id + + if span_id is None: + return None + + current_span = self.spans.get(span_id) + + if current_span is None: + return None + + root_id = current_span["root_id"] + + if not root_id: + return None + + root_span = self.spans.get(root_id) + + if root_span is None: + # span is already the root, uuid its own id + return str(uuid.uuid5(literalai_uuid_namespace, span_id)) + else: + # uuid the id of the root span + return str(uuid.uuid5(literalai_uuid_namespace, root_span["id"])) + + @classmethod + def class_name(cls) -> str: + """Class name.""" + return "LiteralSpanHandler"