From cd89a5ed4c4eab41426743efb7f84a4826a066ed Mon Sep 17 00:00:00 2001 From: modenter Date: Mon, 23 Sep 2024 14:45:28 +0200 Subject: [PATCH 1/7] tmp: add instrumentation for standalone llm calls with llamaindex (without query engine) --- literalai/instrumentation/llamaindex.py | 49 +++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/literalai/instrumentation/llamaindex.py b/literalai/instrumentation/llamaindex.py index 0b9fbb0..e76c0d3 100644 --- a/literalai/instrumentation/llamaindex.py +++ b/literalai/instrumentation/llamaindex.py @@ -1,5 +1,6 @@ import logging import uuid +import time from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast from typing_extensions import TypedDict from llama_index.core.base.llms.types import MessageRole @@ -115,6 +116,7 @@ class LiteralEventHandler(BaseEventHandler): _span_handler: "LiteralSpanHandler" = PrivateAttr(...) runs: Dict[str, List[Step]] = {} streaming_run_ids: List[str] = [] + _standalone_step_id: Optional[str] = None class Config: arbitrary_types_allowed = True @@ -133,6 +135,7 @@ def class_name(cls) -> str: """Class name.""" return "LiteralEventHandler" + def handle(self, event: BaseEvent, **kwargs) -> None: """Logic for handling event.""" try: @@ -141,6 +144,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: """The events are presented here roughly in chronological order""" if isinstance(event, QueryStartEvent): + print("received \033[93mQueryStart\033[0m signal, run_id:", run_id) active_thread = active_thread_var.get() query = extract_query_from_bundle(event.query) @@ -156,6 +160,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: ) if isinstance(event, RetrievalStartEvent): + print("received \033[93mRetrievalStart\033[0m signal, run_id:", run_id) run = self._client.start_step( name="RAG", type="run", @@ -175,6 +180,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: self.store_step(run_id=run_id, step=retrieval_step) if isinstance(event, EmbeddingStartEvent): + print("received \033[93mEmbeddingStart\033[0m signal, run_id:", run_id) retrieval_step = self.get_first_step_of_type( run_id=run_id, step_type="retrieval" ) @@ -190,6 +196,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: self.store_step(run_id=run_id, step=embedding_step) if isinstance(event, EmbeddingEndEvent): + print("received \033[93mEmbeddingEnd\033[0m signal, run_id:", run_id) embedding_step = self.get_first_step_of_type( run_id=run_id, step_type="embedding" ) @@ -200,6 +207,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: embedding_step.end() if isinstance(event, RetrievalEndEvent): + print("received \033[93mRetrievalEnd\033[0m signal, run_id:", run_id) retrieval_step = self.get_first_step_of_type( run_id=run_id, step_type="retrieval" ) @@ -213,6 +221,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: retrieval_step.end() if isinstance(event, GetResponseStartEvent): + print("received \033[93mGetResponseStart\033[0m signal, run_id:", run_id) if run_id: self._client.step() llm_step = self._client.start_step( @@ -223,21 +232,40 @@ def handle(self, event: BaseEvent, **kwargs) -> None: self.store_step(run_id=run_id, step=llm_step) if isinstance(event, LLMChatStartEvent): + print("received \033[93mLLMChatStart\033[0m signal, run_id:", run_id) + print("\033[94m" + str(event) + "\033[0m") llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm") - if run_id and llm_step: + if not run_id and not llm_step: + self._standalone_step_id = str(uuid.uuid4()) + llm_step = 
self._client.start_step( + name=event.model_dict.get("model", "LLM"), + type="llm", + id=self._standalone_step_id, + # Remove thread_id for standalone runs + ) + self.store_step(run_id=self._standalone_step_id, step=llm_step) + + if llm_step: generation = create_generation(event=event) llm_step.generation = generation llm_step.name = event.model_dict.get("model") if isinstance(event, LLMChatEndEvent): + print("received \033[93mLLMChatEnd\033[0m signal, run_id:", run_id) + print("\033[94m" + str(event) + "\033[0m") llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm") + + if not llm_step and self._standalone_step_id: + llm_step = self.get_first_step_of_type(run_id=self._standalone_step_id, step_type="llm") + response = event.response - if run_id and llm_step and response: + if llm_step and response: chat_completion = response.raw if isinstance(chat_completion, ChatCompletion): + print("\033[92m" + str(chat_completion.__dict__) + "\033[0m") usage = chat_completion.usage if isinstance(usage, CompletionUsage): @@ -246,7 +274,19 @@ def handle(self, event: BaseEvent, **kwargs) -> None: usage.completion_tokens ) + if self._standalone_step_id: + llm_step.end() + self._standalone_step_id = None + + # Create a message for both standalone and regular runs + self._client.message( + type="assistant_message", + thread_id=thread_id if run_id else None, # Use thread_id only for regular runs + content=response.message.content, + ) + if isinstance(event, SynthesizeEndEvent): + print("received \033[93mSynthesizeEnd\033[0m signal, run_id:", run_id) llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm") run = self.get_first_step_of_type(run_id=run_id, step_type="run") @@ -274,6 +314,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: ) if isinstance(event, QueryEndEvent): + print("received \033[93mQueryEnd\033[0m signal, run_id:", run_id) if run_id in self.runs: del self.runs[run_id] @@ -444,7 +485,7 @@ def prepare_to_exit_span( **kwargs: Any, ): """Logic for preparing to exit a span.""" - if id in self.spans: + if id_ in self.spans: del self.spans[id_] def prepare_to_drop_span( @@ -456,7 +497,7 @@ def prepare_to_drop_span( **kwargs: Any, ): """Logic for preparing to drop a span.""" - if id in self.spans: + if id_ in self.spans: del self.spans[id_] From 619cd88675a1c14e3030442519ac947f072fa12a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hugues=20de=20Saxc=C3=A9?= Date: Wed, 25 Sep 2024 14:38:00 +0200 Subject: [PATCH 2/7] feat: handle llm.chat and predict_and_call --- literalai/instrumentation/llamaindex.py | 90 +++++++++++++++---------- 1 file changed, 55 insertions(+), 35 deletions(-) diff --git a/literalai/instrumentation/llamaindex.py b/literalai/instrumentation/llamaindex.py index e76c0d3..fbf2886 100644 --- a/literalai/instrumentation/llamaindex.py +++ b/literalai/instrumentation/llamaindex.py @@ -1,6 +1,5 @@ import logging import uuid -import time from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast from typing_extensions import TypedDict from llama_index.core.base.llms.types import MessageRole @@ -8,6 +7,7 @@ from llama_index.core.instrumentation import get_dispatcher from llama_index.core.instrumentation.event_handlers import BaseEventHandler from llama_index.core.instrumentation.events import BaseEvent +from llama_index.core.base.llms.types import ChatMessage from llama_index.core.instrumentation.events.embedding import ( EmbeddingEndEvent, EmbeddingStartEvent, @@ -16,13 +16,13 @@ LLMChatEndEvent, LLMChatStartEvent, ) + from 
llama_index.core.instrumentation.events.query import QueryEndEvent, QueryStartEvent from llama_index.core.instrumentation.events.retrieval import ( RetrievalEndEvent, RetrievalStartEvent, ) from llama_index.core.instrumentation.events.synthesis import ( - GetResponseStartEvent, SynthesizeEndEvent, ) from llama_index.core.instrumentation.span import SimpleSpan @@ -34,7 +34,11 @@ from pydantic import PrivateAttr from literalai.context import active_thread_var -from literalai.observability.generation import ChatGeneration, GenerationMessageRole +from literalai.observability.generation import ( + ChatGeneration, + GenerationMessage, + GenerationMessageRole, +) from literalai.observability.step import Step, StepType if TYPE_CHECKING: @@ -89,6 +93,26 @@ def extract_document_info(nodes: List[NodeWithScore]): ] +def build_message_dict(message: ChatMessage): + message_dict = { + "role": convert_message_role(message.role), + "content": message.content, + } + kwargs = message.additional_kwargs + + if kwargs: + if kwargs.get("tool_call_id", None): + message_dict["tool_call_id"] = kwargs.get("tool_call_id") + if kwargs.get("name", None): + message_dict["name"] = kwargs.get("name") + tool_calls = kwargs.get("tool_calls", []) + if len(tool_calls) > 0: + message_dict["tool_calls"] = [ + tool_call.to_dict() for tool_call in tool_calls + ] + return message_dict + + def create_generation(event: LLMChatStartEvent): model_dict = event.model_dict @@ -102,10 +126,7 @@ def create_generation(event: LLMChatStartEvent): "logprobs": model_dict.get("logprobs"), "top_logprobs": model_dict.get("top_logprobs"), }, - messages=[ - {"role": convert_message_role(message.role), "content": message.content} - for message in event.messages - ], + messages=[build_message_dict(message) for message in event.messages], ) @@ -135,6 +156,23 @@ def class_name(cls) -> str: """Class name.""" return "LiteralEventHandler" + def _convert_message( + self, + message: ChatMessage, + ): + tool_calls = message.additional_kwargs.get("tool_calls") + msg = GenerationMessage( + name=getattr(message, "name", None), + role=convert_message_role(message.role), + content="", + ) + + msg["content"] = message.content + + if tool_calls: + msg["tool_calls"] = [tool_call.to_dict() for tool_call in tool_calls] + + return msg def handle(self, event: BaseEvent, **kwargs) -> None: """Logic for handling event.""" @@ -144,7 +182,6 @@ def handle(self, event: BaseEvent, **kwargs) -> None: """The events are presented here roughly in chronological order""" if isinstance(event, QueryStartEvent): - print("received \033[93mQueryStart\033[0m signal, run_id:", run_id) active_thread = active_thread_var.get() query = extract_query_from_bundle(event.query) @@ -160,7 +197,6 @@ def handle(self, event: BaseEvent, **kwargs) -> None: ) if isinstance(event, RetrievalStartEvent): - print("received \033[93mRetrievalStart\033[0m signal, run_id:", run_id) run = self._client.start_step( name="RAG", type="run", @@ -180,7 +216,6 @@ def handle(self, event: BaseEvent, **kwargs) -> None: self.store_step(run_id=run_id, step=retrieval_step) if isinstance(event, EmbeddingStartEvent): - print("received \033[93mEmbeddingStart\033[0m signal, run_id:", run_id) retrieval_step = self.get_first_step_of_type( run_id=run_id, step_type="retrieval" ) @@ -196,7 +231,6 @@ def handle(self, event: BaseEvent, **kwargs) -> None: self.store_step(run_id=run_id, step=embedding_step) if isinstance(event, EmbeddingEndEvent): - print("received \033[93mEmbeddingEnd\033[0m signal, run_id:", run_id) embedding_step = 
self.get_first_step_of_type( run_id=run_id, step_type="embedding" ) @@ -207,7 +241,6 @@ def handle(self, event: BaseEvent, **kwargs) -> None: embedding_step.end() if isinstance(event, RetrievalEndEvent): - print("received \033[93mRetrievalEnd\033[0m signal, run_id:", run_id) retrieval_step = self.get_first_step_of_type( run_id=run_id, step_type="retrieval" ) @@ -220,8 +253,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: retrieval_step.output = {"retrieved_documents": retrieved_documents} retrieval_step.end() - if isinstance(event, GetResponseStartEvent): - print("received \033[93mGetResponseStart\033[0m signal, run_id:", run_id) + if isinstance(event, LLMChatStartEvent): if run_id: self._client.step() llm_step = self._client.start_step( @@ -230,10 +262,6 @@ def handle(self, event: BaseEvent, **kwargs) -> None: thread_id=thread_id, ) self.store_step(run_id=run_id, step=llm_step) - - if isinstance(event, LLMChatStartEvent): - print("received \033[93mLLMChatStart\033[0m signal, run_id:", run_id) - print("\033[94m" + str(event) + "\033[0m") llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm") if not run_id and not llm_step: @@ -252,20 +280,17 @@ def handle(self, event: BaseEvent, **kwargs) -> None: llm_step.name = event.model_dict.get("model") if isinstance(event, LLMChatEndEvent): - print("received \033[93mLLMChatEnd\033[0m signal, run_id:", run_id) - print("\033[94m" + str(event) + "\033[0m") llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm") - if not llm_step and self._standalone_step_id: - llm_step = self.get_first_step_of_type(run_id=self._standalone_step_id, step_type="llm") + llm_step = self.get_first_step_of_type( + run_id=self._standalone_step_id, step_type="llm" + ) response = event.response if llm_step and response: chat_completion = response.raw - if isinstance(chat_completion, ChatCompletion): - print("\033[92m" + str(chat_completion.__dict__) + "\033[0m") usage = chat_completion.usage if isinstance(usage, CompletionUsage): @@ -274,19 +299,15 @@ def handle(self, event: BaseEvent, **kwargs) -> None: usage.completion_tokens ) - if self._standalone_step_id: - llm_step.end() - self._standalone_step_id = None + if self._standalone_step_id: + llm_step.generation.message_completion = ( + self._convert_message(response.message) + ) - # Create a message for both standalone and regular runs - self._client.message( - type="assistant_message", - thread_id=thread_id if run_id else None, # Use thread_id only for regular runs - content=response.message.content, - ) + llm_step.end() + self._standalone_step_id = None if isinstance(event, SynthesizeEndEvent): - print("received \033[93mSynthesizeEnd\033[0m signal, run_id:", run_id) llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm") run = self.get_first_step_of_type(run_id=run_id, step_type="run") @@ -314,7 +335,6 @@ def handle(self, event: BaseEvent, **kwargs) -> None: ) if isinstance(event, QueryEndEvent): - print("received \033[93mQueryEnd\033[0m signal, run_id:", run_id) if run_id in self.runs: del self.runs[run_id] From 27f023f05db73624ba08712fce8bffbeb237d2d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hugues=20de=20Saxc=C3=A9?= Date: Thu, 26 Sep 2024 16:21:33 +0200 Subject: [PATCH 3/7] fix: agent runs --- .../instrumentation/llamaindex/__init__.py | 17 ++ .../event_handler.py} | 248 +++++------------- .../llamaindex/span_handler.py | 182 +++++++++++++ 3 files changed, 264 insertions(+), 183 deletions(-) create mode 100644 literalai/instrumentation/llamaindex/__init__.py rename 
literalai/instrumentation/{llamaindex.py => llamaindex/event_handler.py} (70%) create mode 100644 literalai/instrumentation/llamaindex/span_handler.py diff --git a/literalai/instrumentation/llamaindex/__init__.py b/literalai/instrumentation/llamaindex/__init__.py new file mode 100644 index 0000000..bbc0662 --- /dev/null +++ b/literalai/instrumentation/llamaindex/__init__.py @@ -0,0 +1,17 @@ +from literalai.client import LiteralClient +from llama_index.core.instrumentation import get_dispatcher + +from literalai.instrumentation.llamaindex.event_handler import LiteralEventHandler +from literalai.instrumentation.llamaindex.span_handler import LiteralSpanHandler + + +def instrument_llamaindex(client: "LiteralClient"): + root_dispatcher = get_dispatcher() + + span_handler = LiteralSpanHandler() + root_dispatcher.add_span_handler(span_handler) + + event_handler = LiteralEventHandler( + literal_client=client, llama_index_span_handler=span_handler + ) + root_dispatcher.add_event_handler(event_handler) diff --git a/literalai/instrumentation/llamaindex.py b/literalai/instrumentation/llamaindex/event_handler.py similarity index 70% rename from literalai/instrumentation/llamaindex.py rename to literalai/instrumentation/llamaindex/event_handler.py index fbf2886..b8dd05e 100644 --- a/literalai/instrumentation/llamaindex.py +++ b/literalai/instrumentation/llamaindex/event_handler.py @@ -1,20 +1,26 @@ -import logging import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast -from typing_extensions import TypedDict -from llama_index.core.base.llms.types import MessageRole -from llama_index.core.base.response.schema import Response, StreamingResponse -from llama_index.core.instrumentation import get_dispatcher +import logging +from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast + from llama_index.core.instrumentation.event_handlers import BaseEventHandler from llama_index.core.instrumentation.events import BaseEvent -from llama_index.core.base.llms.types import ChatMessage -from llama_index.core.instrumentation.events.embedding import ( - EmbeddingEndEvent, - EmbeddingStartEvent, -) +from pydantic import PrivateAttr + +from literalai.instrumentation.llamaindex.span_handler import LiteralSpanHandler +from literalai.context import active_thread_var from llama_index.core.instrumentation.events.llm import ( LLMChatEndEvent, - LLMChatStartEvent, +) + +from llama_index.core.schema import QueryBundle +from llama_index.core.instrumentation.events.agent import ( + AgentChatWithStepStartEvent, + AgentChatWithStepEndEvent, + AgentRunStepStartEvent, + AgentRunStepEndEvent, +) +from llama_index.core.instrumentation.events.embedding import ( + EmbeddingEndEvent, ) from llama_index.core.instrumentation.events.query import QueryEndEvent, QueryStartEvent @@ -22,18 +28,23 @@ RetrievalEndEvent, RetrievalStartEvent, ) + +from llama_index.core.base.llms.types import MessageRole, ChatMessage +from llama_index.core.base.response.schema import Response, StreamingResponse + +from llama_index.core.instrumentation.events.llm import ( + LLMChatEndEvent, + LLMChatStartEvent, +) + from llama_index.core.instrumentation.events.synthesis import ( SynthesizeEndEvent, ) -from llama_index.core.instrumentation.span import SimpleSpan -from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler -from llama_index.core.query_engine import RetrieverQueryEngine + from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode from openai.types import CompletionUsage from 
openai.types.chat import ChatCompletion -from pydantic import PrivateAttr -from literalai.context import active_thread_var from literalai.observability.generation import ( ChatGeneration, GenerationMessage, @@ -44,8 +55,6 @@ if TYPE_CHECKING: from literalai.client import LiteralClient -literalai_uuid_namespace = uuid.UUID("05f6b2b5-a912-47bd-958f-98a9c4496322") - def convert_message_role(role: MessageRole) -> GenerationMessageRole: mapping = { @@ -130,6 +139,14 @@ def create_generation(event: LLMChatStartEvent): ) +def print_blue(text: str): + print(f"\033[34m{text}\033[0m") + + +def get_query(a: Union[str, QueryBundle]): + return a.query_str if isinstance(a, QueryBundle) else a + + class LiteralEventHandler(BaseEventHandler): """This class handles events coming from LlamaIndex.""" @@ -138,6 +155,7 @@ class LiteralEventHandler(BaseEventHandler): runs: Dict[str, List[Step]] = {} streaming_run_ids: List[str] = [] _standalone_step_id: Optional[str] = None + open_runs: List[Step] = [] class Config: arbitrary_types_allowed = True @@ -176,11 +194,39 @@ def _convert_message( def handle(self, event: BaseEvent, **kwargs) -> None: """Logic for handling event.""" + tabs = self._span_handler.symbol * self._span_handler.tab_indent + print_blue(f"{tabs}{event.class_name()}") try: thread_id = self._span_handler.get_thread_id(event.span_id) run_id = self._span_handler.get_run_id(event.span_id) """The events are presented here roughly in chronological order""" + if isinstance(event, AgentChatWithStepStartEvent) or isinstance( + event, AgentRunStepStartEvent + ): + parent_run_id = None + if len(self.open_runs) > 0: + parent_run_id = self.open_runs[-1].id + + agent_run_id = str(uuid.uuid4()) + + run = self._client.start_step( + name="Agent", + type="run", + id=agent_run_id, + thread_id=thread_id, + parent_id=parent_run_id, + ) + + self.open_runs.append(run) + + if isinstance(event, AgentChatWithStepEndEvent) or isinstance( + event, AgentRunStepEndEvent + ): + step = self.open_runs.pop() + if step: + step.end() + if isinstance(event, QueryStartEvent): active_thread = active_thread_var.get() query = extract_query_from_bundle(event.query) @@ -365,167 +411,3 @@ def get_first_step_of_type( return step return None - - -class SpanEntry(TypedDict): - id: str - parent_id: Optional[str] - root_id: Optional[str] - is_run_root: bool - - -class LiteralSpanHandler(BaseSpanHandler[SimpleSpan]): - """This class handles spans coming from LlamaIndex.""" - - spans: Dict[str, SpanEntry] = {} - - @classmethod - def class_name(cls) -> str: - """Class name.""" - return "LiteralSpanHandler" - - def new_span( - self, - id_: str, - bound_args: Any, - instance: Optional[Any] = None, - parent_span_id: Optional[str] = None, - tags: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ): - self.spans[id_] = { - "id": id_, - "parent_id": parent_span_id, - "root_id": None, - "is_run_root": self.is_run_root(instance, parent_span_id), - } - - if parent_span_id is not None: - self.spans[id_]["root_id"] = self.get_root_span_id(parent_span_id) - else: - self.spans[id_]["root_id"] = id_ - - def is_run_root( - self, instance: Optional[Any], parent_span_id: Optional[str] - ) -> bool: - """Returns True if the span is of type RetrieverQueryEngine, and it has no run root in its parent chain""" - if not isinstance(instance, RetrieverQueryEngine): - return False - - # Span is of correct type, we check that it doesn't have a run root in its parent chain - while parent_span_id: - parent_span = self.spans.get(parent_span_id) - - if not parent_span: - 
parent_span_id = None - continue - - if parent_span["is_run_root"]: - return False - - parent_span_id = parent_span["parent_id"] - - return True - - def get_root_span_id(self, span_id: Optional[str]): - """Finds the root span and returns its ID""" - if not span_id: - return None - - current_span = self.spans.get(span_id) - - if current_span is None: - return None - - while current_span["parent_id"] is not None: - current_span = self.spans.get(current_span["parent_id"]) - if current_span is None: - return None - - return current_span["id"] - - def get_run_id(self, span_id: Optional[str]): - """Go up the span chain to find a run_root, return its ID (or None)""" - if not span_id: - return None - - current_span = self.spans.get(span_id) - - if current_span is None: - return None - - while current_span: - if current_span["is_run_root"]: - return str(uuid.uuid5(literalai_uuid_namespace, current_span["id"])) - - parent_id = current_span["parent_id"] - - if parent_id: - current_span = self.spans.get(parent_id) - else: - current_span = None - - return None - - def get_thread_id(self, span_id: Optional[str]): - """Returns the root span ID as a thread ID""" - active_thread = active_thread_var.get() - - if active_thread: - return active_thread.id - - if span_id is None: - return None - - current_span = self.spans.get(span_id) - - if current_span is None: - return None - - root_id = current_span["root_id"] - - if not root_id: - return None - - root_span = self.spans.get(root_id) - - if root_span is None: - # span is already the root, uuid its own id - return str(uuid.uuid5(literalai_uuid_namespace, span_id)) - else: - # uuid the id of the root span - return str(uuid.uuid5(literalai_uuid_namespace, root_span["id"])) - - def prepare_to_exit_span( - self, - id_: str, - bound_args: Any, - instance: Optional[Any] = None, - result: Optional[Any] = None, - **kwargs: Any, - ): - """Logic for preparing to exit a span.""" - if id_ in self.spans: - del self.spans[id_] - - def prepare_to_drop_span( - self, - id_: str, - bound_args: Any, - instance: Optional[Any] = None, - err: Optional[BaseException] = None, - **kwargs: Any, - ): - """Logic for preparing to drop a span.""" - if id_ in self.spans: - del self.spans[id_] - - -def instrument_llamaindex(client: "LiteralClient"): - root_dispatcher = get_dispatcher() - span_handler = LiteralSpanHandler() - event_handler = LiteralEventHandler( - literal_client=client, llama_index_span_handler=span_handler - ) - root_dispatcher.add_event_handler(event_handler) - root_dispatcher.add_span_handler(span_handler) diff --git a/literalai/instrumentation/llamaindex/span_handler.py b/literalai/instrumentation/llamaindex/span_handler.py new file mode 100644 index 0000000..5e78997 --- /dev/null +++ b/literalai/instrumentation/llamaindex/span_handler.py @@ -0,0 +1,182 @@ +from typing_extensions import TypedDict +from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler +from llama_index.core.instrumentation.span import SimpleSpan +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast +from llama_index.core.tools import BaseTool, FunctionTool +from llama_index.core.query_engine import RetrieverQueryEngine +import uuid +from literalai.context import active_thread_var + +literalai_uuid_namespace = uuid.UUID("05f6b2b5-a912-47bd-958f-98a9c4496322") + + +def print_red(text: str): + print(f"\033[31m{text}\033[0m") + + +class SpanEntry(TypedDict): + id: str + parent_id: Optional[str] + root_id: Optional[str] + is_run_root: bool + + +class 
LiteralSpanHandler(BaseSpanHandler[SimpleSpan]): + """This class handles spans coming from LlamaIndex.""" + + spans: Dict[str, SpanEntry] = {} + tab_indent: int = 0 + symbol: str = " " + + def __init__(self): + super().__init__() + self.tab_indent = 0 + self.symbol = " " + + def new_span( + self, + id_: str, + bound_args: Any, + instance: Optional[Any] = None, + parent_span_id: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ): + tabs = self.symbol * self.tab_indent + print_red(f"{tabs}{type(instance).__name__} #{id_[-6:]}") + + self.spans[id_] = { + "id": id_, + "parent_id": parent_span_id, + "root_id": None, + "is_run_root": self.is_run_root(instance, parent_span_id), + } + + if parent_span_id is not None: + self.spans[id_]["root_id"] = self.get_root_span_id(parent_span_id) + else: + self.spans[id_]["root_id"] = id_ + + def prepare_to_exit_span( + self, + id_: str, + bound_args: Any, + instance: Optional[Any] = None, + result: Optional[Any] = None, + **kwargs: Any, + ): + """Logic for preparing to exit a span.""" + tabs = self.symbol * self.tab_indent + print_red(f"{tabs}{type(instance).__name__} #{id_[-6:]}") + if id_ in self.spans: + del self.spans[id_] + + def prepare_to_drop_span( + self, + id_: str, + bound_args: Any, + instance: Optional[Any] = None, + err: Optional[BaseException] = None, + **kwargs: Any, + ): + """Logic for preparing to drop a span.""" + tabs = self.symbol * self.tab_indent + print_red(f"{tabs}{type(instance).__name__} #{id_[-6:]}") + if id_ in self.spans: + del self.spans[id_] + + def is_run_root( + self, instance: Optional[Any], parent_span_id: Optional[str] + ) -> bool: + """Returns True if the span is of type RetrieverQueryEngine, and it has no run root in its parent chain""" + if not isinstance(instance, RetrieverQueryEngine): + return False + + # Span is of correct type, we check that it doesn't have a run root in its parent chain + while parent_span_id: + parent_span = self.spans.get(parent_span_id) + + if not parent_span: + parent_span_id = None + continue + + if parent_span["is_run_root"]: + return False + + parent_span_id = parent_span["parent_id"] + + return True + + def get_root_span_id(self, span_id: Optional[str]): + """Finds the root span and returns its ID""" + if not span_id: + return None + + current_span = self.spans.get(span_id) + + if current_span is None: + return None + + while current_span["parent_id"] is not None: + current_span = self.spans.get(current_span["parent_id"]) + if current_span is None: + return None + + return current_span["id"] + + def get_run_id(self, span_id: Optional[str]): + """Go up the span chain to find a run_root, return its ID (or None)""" + if not span_id: + return None + + current_span = self.spans.get(span_id) + + if current_span is None: + return None + + while current_span: + if current_span["is_run_root"]: + return str(uuid.uuid5(literalai_uuid_namespace, current_span["id"])) + + parent_id = current_span["parent_id"] + + if parent_id: + current_span = self.spans.get(parent_id) + else: + current_span = None + + return None + + def get_thread_id(self, span_id: Optional[str]): + """Returns the root span ID as a thread ID""" + active_thread = active_thread_var.get() + + if active_thread: + return active_thread.id + + if span_id is None: + return None + + current_span = self.spans.get(span_id) + + if current_span is None: + return None + + root_id = current_span["root_id"] + + if not root_id: + return None + + root_span = self.spans.get(root_id) + + if root_span is None: + # span is 
already the root, uuid its own id + return str(uuid.uuid5(literalai_uuid_namespace, span_id)) + else: + # uuid the id of the root span + return str(uuid.uuid5(literalai_uuid_namespace, root_span["id"])) + + @classmethod + def class_name(cls) -> str: + """Class name.""" + return "LiteralSpanHandler" From d349603d21c75a296efdbd8692eb1fc786a1bc02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hugues=20de=20Saxc=C3=A9?= Date: Thu, 26 Sep 2024 16:52:00 +0200 Subject: [PATCH 4/7] fix: remove prints and enable streaming --- .../llamaindex/event_handler.py | 33 +++++++++---------- .../llamaindex/span_handler.py | 19 +---------- 2 files changed, 17 insertions(+), 35 deletions(-) diff --git a/literalai/instrumentation/llamaindex/event_handler.py b/literalai/instrumentation/llamaindex/event_handler.py index b8dd05e..37120bf 100644 --- a/literalai/instrumentation/llamaindex/event_handler.py +++ b/literalai/instrumentation/llamaindex/event_handler.py @@ -20,6 +20,7 @@ AgentRunStepEndEvent, ) from llama_index.core.instrumentation.events.embedding import ( + EmbeddingStartEvent, EmbeddingEndEvent, ) @@ -43,7 +44,7 @@ from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode from openai.types import CompletionUsage -from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletion, ChatCompletionChunk from literalai.observability.generation import ( ChatGeneration, @@ -107,6 +108,7 @@ def build_message_dict(message: ChatMessage): "role": convert_message_role(message.role), "content": message.content, } + kwargs = message.additional_kwargs if kwargs: @@ -139,12 +141,8 @@ def create_generation(event: LLMChatStartEvent): ) -def print_blue(text: str): - print(f"\033[34m{text}\033[0m") - - -def get_query(a: Union[str, QueryBundle]): - return a.query_str if isinstance(a, QueryBundle) else a +def extract_query(x: Union[str, QueryBundle]): + return x.query_str if isinstance(x, QueryBundle) else x class LiteralEventHandler(BaseEventHandler): @@ -169,11 +167,7 @@ def __init__( object.__setattr__(self, "_client", literal_client) object.__setattr__(self, "_span_handler", llama_index_span_handler) - @classmethod - def class_name(cls) -> str: - """Class name.""" - return "LiteralEventHandler" - + def _convert_message( self, message: ChatMessage, @@ -194,8 +188,6 @@ def _convert_message( def handle(self, event: BaseEvent, **kwargs) -> None: """Logic for handling event.""" - tabs = self._span_handler.symbol * self._span_handler.tab_indent - print_blue(f"{tabs}{event.class_name()}") try: thread_id = self._span_handler.get_thread_id(event.span_id) run_id = self._span_handler.get_run_id(event.span_id) @@ -229,7 +221,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: if isinstance(event, QueryStartEvent): active_thread = active_thread_var.get() - query = extract_query_from_bundle(event.query) + query = extract_query(event.query) if not active_thread or not active_thread.name: self._client.api.upsert_thread(id=thread_id, name=query) @@ -293,7 +285,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: if run_id and retrieval_step: retrieved_documents = extract_document_info(event.nodes) - query = extract_query_from_bundle(event.str_or_query_bundle) + query = extract_query(event.str_or_query_bundle) retrieval_step.input = {"query": query} retrieval_step.output = {"retrieved_documents": retrieved_documents} @@ -336,7 +328,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: if llm_step and response: chat_completion = response.raw - if isinstance(chat_completion, 
ChatCompletion): + if isinstance(chat_completion, ChatCompletion) or isinstance(chat_completion, ChatCompletionChunk): usage = chat_completion.usage if isinstance(usage, CompletionUsage): @@ -350,6 +342,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: self._convert_message(response.message) ) + llm_step.end() self._standalone_step_id = None @@ -411,3 +404,9 @@ def get_first_step_of_type( return step return None + + @classmethod + def class_name(cls) -> str: + """Class name.""" + return "LiteralEventHandler" + diff --git a/literalai/instrumentation/llamaindex/span_handler.py b/literalai/instrumentation/llamaindex/span_handler.py index 5e78997..142fd0e 100644 --- a/literalai/instrumentation/llamaindex/span_handler.py +++ b/literalai/instrumentation/llamaindex/span_handler.py @@ -1,19 +1,13 @@ from typing_extensions import TypedDict from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler from llama_index.core.instrumentation.span import SimpleSpan -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast -from llama_index.core.tools import BaseTool, FunctionTool +from typing import Any, Dict, Optional from llama_index.core.query_engine import RetrieverQueryEngine import uuid from literalai.context import active_thread_var literalai_uuid_namespace = uuid.UUID("05f6b2b5-a912-47bd-958f-98a9c4496322") - -def print_red(text: str): - print(f"\033[31m{text}\033[0m") - - class SpanEntry(TypedDict): id: str parent_id: Optional[str] @@ -25,13 +19,9 @@ class LiteralSpanHandler(BaseSpanHandler[SimpleSpan]): """This class handles spans coming from LlamaIndex.""" spans: Dict[str, SpanEntry] = {} - tab_indent: int = 0 - symbol: str = " " def __init__(self): super().__init__() - self.tab_indent = 0 - self.symbol = " " def new_span( self, @@ -42,9 +32,6 @@ def new_span( tags: Optional[Dict[str, Any]] = None, **kwargs: Any, ): - tabs = self.symbol * self.tab_indent - print_red(f"{tabs}{type(instance).__name__} #{id_[-6:]}") - self.spans[id_] = { "id": id_, "parent_id": parent_span_id, @@ -66,8 +53,6 @@ def prepare_to_exit_span( **kwargs: Any, ): """Logic for preparing to exit a span.""" - tabs = self.symbol * self.tab_indent - print_red(f"{tabs}{type(instance).__name__} #{id_[-6:]}") if id_ in self.spans: del self.spans[id_] @@ -80,8 +65,6 @@ def prepare_to_drop_span( **kwargs: Any, ): """Logic for preparing to drop a span.""" - tabs = self.symbol * self.tab_indent - print_red(f"{tabs}{type(instance).__name__} #{id_[-6:]}") if id_ in self.spans: del self.spans[id_] From 468ae23ac6960ea88907bcc13795aa2771c89a7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hugues=20de=20Saxc=C3=A9?= Date: Thu, 26 Sep 2024 17:18:11 +0200 Subject: [PATCH 5/7] fix: linting --- .../llamaindex/event_handler.py | 19 +++++++++---------- .../llamaindex/span_handler.py | 1 + 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/literalai/instrumentation/llamaindex/event_handler.py b/literalai/instrumentation/llamaindex/event_handler.py index 37120bf..dc61511 100644 --- a/literalai/instrumentation/llamaindex/event_handler.py +++ b/literalai/instrumentation/llamaindex/event_handler.py @@ -8,11 +8,7 @@ from literalai.instrumentation.llamaindex.span_handler import LiteralSpanHandler from literalai.context import active_thread_var -from llama_index.core.instrumentation.events.llm import ( - LLMChatEndEvent, -) -from llama_index.core.schema import QueryBundle from llama_index.core.instrumentation.events.agent import ( AgentChatWithStepStartEvent, AgentChatWithStepEndEvent, @@ -167,7 
+163,6 @@ def __init__( object.__setattr__(self, "_client", literal_client) object.__setattr__(self, "_span_handler", llama_index_span_handler) - def _convert_message( self, message: ChatMessage, @@ -196,6 +191,11 @@ def handle(self, event: BaseEvent, **kwargs) -> None: if isinstance(event, AgentChatWithStepStartEvent) or isinstance( event, AgentRunStepStartEvent ): + run_name = ( + "Agent Chat" + if isinstance(event, AgentChatWithStepStartEvent) + else "Agent Step" + ) parent_run_id = None if len(self.open_runs) > 0: parent_run_id = self.open_runs[-1].id @@ -203,10 +203,9 @@ def handle(self, event: BaseEvent, **kwargs) -> None: agent_run_id = str(uuid.uuid4()) run = self._client.start_step( - name="Agent", + name=run_name, type="run", id=agent_run_id, - thread_id=thread_id, parent_id=parent_run_id, ) @@ -328,7 +327,9 @@ def handle(self, event: BaseEvent, **kwargs) -> None: if llm_step and response: chat_completion = response.raw - if isinstance(chat_completion, ChatCompletion) or isinstance(chat_completion, ChatCompletionChunk): + if isinstance(chat_completion, ChatCompletion) or isinstance( + chat_completion, ChatCompletionChunk + ): usage = chat_completion.usage if isinstance(usage, CompletionUsage): @@ -342,7 +343,6 @@ def handle(self, event: BaseEvent, **kwargs) -> None: self._convert_message(response.message) ) - llm_step.end() self._standalone_step_id = None @@ -409,4 +409,3 @@ def get_first_step_of_type( def class_name(cls) -> str: """Class name.""" return "LiteralEventHandler" - diff --git a/literalai/instrumentation/llamaindex/span_handler.py b/literalai/instrumentation/llamaindex/span_handler.py index 142fd0e..7c5c9f1 100644 --- a/literalai/instrumentation/llamaindex/span_handler.py +++ b/literalai/instrumentation/llamaindex/span_handler.py @@ -8,6 +8,7 @@ literalai_uuid_namespace = uuid.UUID("05f6b2b5-a912-47bd-958f-98a9c4496322") + class SpanEntry(TypedDict): id: str parent_id: Optional[str] From 76b4e4412ae50dbc3a75fb538bb945534be9aadb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hugues=20de=20Saxc=C3=A9?= Date: Fri, 27 Sep 2024 11:10:34 +0200 Subject: [PATCH 6/7] fix: add doc + try catch --- .../instrumentation/llamaindex/event_handler.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/literalai/instrumentation/llamaindex/event_handler.py b/literalai/instrumentation/llamaindex/event_handler.py index dc61511..96813b0 100644 --- a/literalai/instrumentation/llamaindex/event_handler.py +++ b/literalai/instrumentation/llamaindex/event_handler.py @@ -187,7 +187,8 @@ def handle(self, event: BaseEvent, **kwargs) -> None: thread_id = self._span_handler.get_thread_id(event.span_id) run_id = self._span_handler.get_run_id(event.span_id) - """The events are presented here roughly in chronological order""" + # AgentChatWithStep wraps several AgentRunStep events + # as the agent may want to perform multiple tool calls in a row. if isinstance(event, AgentChatWithStepStartEvent) or isinstance( event, AgentRunStepStartEvent ): @@ -214,7 +215,12 @@ def handle(self, event: BaseEvent, **kwargs) -> None: if isinstance(event, AgentChatWithStepEndEvent) or isinstance( event, AgentRunStepEndEvent ): - step = self.open_runs.pop() + try: + step = self.open_runs.pop() + except IndexError: + logging.error( + "[Literal] Error in Llamaindex instrumentation: AgentRunStepEndEvent called without an open run." 
+ ) if step: step.end() @@ -232,7 +238,8 @@ def handle(self, event: BaseEvent, **kwargs) -> None: thread_id=thread_id, content=query, ) - + + # Retrieval wraps the Embedding step in LlamaIndex if isinstance(event, RetrievalStartEvent): run = self._client.start_step( name="RAG", @@ -290,6 +297,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: retrieval_step.output = {"retrieved_documents": retrieved_documents} retrieval_step.end() + # Only event where we create LLM steps if isinstance(event, LLMChatStartEvent): if run_id: self._client.step() @@ -316,6 +324,7 @@ def handle(self, event: BaseEvent, **kwargs) -> None: llm_step.generation = generation llm_step.name = event.model_dict.get("model") + # Actual creation of the event happens upon ending the event if isinstance(event, LLMChatEndEvent): llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm") if not llm_step and self._standalone_step_id: @@ -327,6 +336,8 @@ def handle(self, event: BaseEvent, **kwargs) -> None: if llm_step and response: chat_completion = response.raw + + # ChatCompletionChunk needed for chat stream methods if isinstance(chat_completion, ChatCompletion) or isinstance( chat_completion, ChatCompletionChunk ): From 39f39292930da55583b140012558e17401749b2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hugues=20de=20Saxc=C3=A9?= Date: Sat, 28 Sep 2024 10:21:55 +0200 Subject: [PATCH 7/7] fix: prevent multiple instrumentation of llamaindex --- literalai/instrumentation/llamaindex/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/literalai/instrumentation/llamaindex/__init__.py b/literalai/instrumentation/llamaindex/__init__.py index bbc0662..dd0a272 100644 --- a/literalai/instrumentation/llamaindex/__init__.py +++ b/literalai/instrumentation/llamaindex/__init__.py @@ -4,8 +4,13 @@ from literalai.instrumentation.llamaindex.event_handler import LiteralEventHandler from literalai.instrumentation.llamaindex.span_handler import LiteralSpanHandler +is_llamaindex_instrumented = False def instrument_llamaindex(client: "LiteralClient"): + global is_llamaindex_instrumented + if is_llamaindex_instrumented: + return + root_dispatcher = get_dispatcher() span_handler = LiteralSpanHandler() @@ -15,3 +20,5 @@ def instrument_llamaindex(client: "LiteralClient"): literal_client=client, llama_index_span_handler=span_handler ) root_dispatcher.add_event_handler(event_handler) + + is_llamaindex_instrumented = True
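
Usage sketch for the final state of this series. This is a minimal example, not part of the patches: it assumes LITERAL_API_KEY is set in the environment and that a LlamaIndex index object already exists (the commented query-engine lines are illustrative only).

    from literalai.client import LiteralClient
    from literalai.instrumentation.llamaindex import instrument_llamaindex

    client = LiteralClient()       # assumes LITERAL_API_KEY is set in the environment
    instrument_llamaindex(client)
    instrument_llamaindex(client)  # second call is a no-op thanks to patch 7's guard

    # Once instrumented, query-engine runs are logged as "RAG" runs with
    # retrieval/embedding/llm child steps, and standalone llm.chat calls
    # (patches 1-2) are logged as their own llm steps:
    #
    #     query_engine = index.as_query_engine()
    #     print(query_engine.query("What does this repository do?"))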
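
The deterministic ID scheme used by LiteralSpanHandler (patch 3) is worth calling out: thread and run IDs are derived from LlamaIndex span IDs with uuid5 over a fixed namespace, so every handler that sees the same root span converges on the same Literal AI IDs without any shared state. A standalone illustration of that property, with a hypothetical span ID:

    import uuid

    # Fixed namespace, copied from span_handler.py
    literalai_uuid_namespace = uuid.UUID("05f6b2b5-a912-47bd-958f-98a9c4496322")

    root_span_id = "RetrieverQueryEngine.query-0001"  # hypothetical span ID

    # uuid5 is deterministic: the same (namespace, name) pair always yields
    # the same UUID, so repeated calls agree across handlers and processes.
    thread_id = str(uuid.uuid5(literalai_uuid_namespace, root_span_id))
    assert thread_id == str(uuid.uuid5(literalai_uuid_namespace, root_span_id))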