From 975359d4fe772b8121e203b47c79b3e1b88f0fa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Thu, 12 Dec 2024 09:28:05 +0100 Subject: [PATCH 01/10] refactor: deprecate instrumentation methods --- literalai/client.py | 4 ++++ literalai/instrumentation/openai.py | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/literalai/client.py b/literalai/client.py index 2156637..1d2f45d 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -1,5 +1,6 @@ import os from typing import Any, Dict, List, Optional, Union +from typing_extensions import deprecated from literalai.api import AsyncLiteralAPI, LiteralAPI from literalai.callback.langchain_callback import get_langchain_callback @@ -92,18 +93,21 @@ def to_sync(self) -> "LiteralClient": else: return self # type: ignore + @deprecated("Use Literal.initialize instead") def instrument_openai(self): """ Instruments the OpenAI SDK so that all LLM calls are logged to Literal AI. """ instrument_openai(self.to_sync()) + @deprecated("Use Literal.initialize instead") def instrument_mistralai(self): """ Instruments the Mistral AI SDK so that all LLM calls are logged to Literal AI. """ instrument_mistralai(self.to_sync()) + @deprecated("Use Literal.initialize instead") def instrument_llamaindex(self): """ Instruments the Llama Index framework so that all RAG & LLM calls are logged to Literal AI. diff --git a/literalai/instrumentation/openai.py b/literalai/instrumentation/openai.py index 8922ced..b1554a4 100644 --- a/literalai/instrumentation/openai.py +++ b/literalai/instrumentation/openai.py @@ -10,7 +10,12 @@ from literalai.context import active_steps_var, active_thread_var from literalai.helper import ensure_values_serializable -from literalai.observability.generation import GenerationMessage, CompletionGeneration, ChatGeneration, GenerationType +from literalai.observability.generation import ( + GenerationMessage, + CompletionGeneration, + ChatGeneration, + GenerationType, +) from literalai.wrappers import AfterContext, BeforeContext, wrap_all logger = logging.getLogger(__name__) From ae005e3f17cb2576e3053be2fd27590ddd721ce4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Thu, 12 Dec 2024 14:03:39 +0100 Subject: [PATCH 02/10] refactor: add basic support --- literalai/client.py | 6 +++ literalai/exporter.py | 89 +++++++++++++++++++++++++++++++ literalai/observability/step.py | 37 ++++++++++++- literalai/observability/thread.py | 12 +++++ requirements.txt | 3 +- 5 files changed, 144 insertions(+), 3 deletions(-) create mode 100644 literalai/exporter.py diff --git a/literalai/client.py b/literalai/client.py index 1d2f45d..09a295d 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -1,4 +1,5 @@ import os +from traceloop.sdk import Traceloop from typing import Any, Dict, List, Optional, Union from typing_extensions import deprecated @@ -11,6 +12,7 @@ experiment_item_run_decorator, ) from literalai.event_processor import EventProcessor +from literalai.exporter import LoggingSpanExporter from literalai.instrumentation.mistralai import instrument_mistralai from literalai.instrumentation.openai import instrument_openai from literalai.my_types import Environment @@ -123,6 +125,10 @@ def instrument_llamaindex(self): instrument_llamaindex(self.to_sync()) + @classmethod + def initialize(cls): + Traceloop.init(exporter=LoggingSpanExporter()) + def langchain_callback( self, to_ignore: Optional[List[str]] = None, diff --git a/literalai/exporter.py b/literalai/exporter.py new file mode 
100644 index 0000000..f94cee9 --- /dev/null +++ b/literalai/exporter.py @@ -0,0 +1,89 @@ +from datetime import datetime, timezone +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult +from typing import Sequence +import logging + +from literalai.helper import utc_now +from literalai.observability.step import Step + +from literalai.context import active_root_run_var, active_steps_var, active_thread_var + + +class LoggingSpanExporter(SpanExporter): + def __init__(self, logger_name: str = "span_exporter"): + self.logger = logging.getLogger(logger_name) + self.logger.setLevel(logging.INFO) + + if not self.logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + self.logger.addHandler(handler) + + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + """Export the spans by logging them.""" + try: + for span in spans: + if ( + span.attributes + and span.attributes.get("llm.is_streaming", None) is not None + ): + step = self._create_step_from_span(span) + self.logger.info(f"Created step from span: {step.to_dict()}") + + return SpanExportResult.SUCCESS + except Exception as e: + self.logger.error(f"Failed to export spans: {e}") + return SpanExportResult.FAILURE + + def shutdown(self): + """Shuts down the exporter.""" + pass + + def force_flush(self, timeout_millis: float = 30000) -> bool: + """Force flush the exporter.""" + return True + + def _create_step_from_span(self, span: ReadableSpan) -> Step: + """Convert a span to a Step object""" + start_time = ( + datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc).isoformat() + if span.start_time + else utc_now() + ) + end_time = ( + datetime.fromtimestamp(span.end_time / 1e9, tz=timezone.utc).isoformat() + if span.end_time + else utc_now() + ) + + self.logger.info(span.attributes) + parent_id = span.attributes.get("literal.parent_id") + thread_id = span.attributes.get("literal.thread_id") + root_run_id = span.attributes.get("literal.root_run_id") + + self.logger.info( + f"From span attributes - Parent ID: {parent_id}, Thread ID: {thread_id}, Root Run ID: {root_run_id}" + ) + + step = Step( + id=(str(span.context.span_id) if span.context else None), + name=span.name, + type="llm", + start_time=start_time, + end_time=end_time, + thread_id=thread_id, + parent_id=parent_id, + root_run_id=root_run_id, + ) + + if span.status.is_ok: + step.error = span.status.description or "Unknown error" + + # Handle input/output/generation based on span attributes + # (We'll implement this in the next iteration) + + return step diff --git a/literalai/observability/step.py b/literalai/observability/step.py index 8a3ee89..27d3c51 100644 --- a/literalai/observability/step.py +++ b/literalai/observability/step.py @@ -16,6 +16,7 @@ from pydantic import Field from pydantic.dataclasses import dataclass +from traceloop.sdk import Traceloop from typing_extensions import TypedDict if TYPE_CHECKING: @@ -69,7 +70,7 @@ class AttachmentDict(TypedDict, total=False): @dataclass(repr=False) class Score(Utils): """ - A score captures information about the quality of a step/experiment item. + A score captures information about the quality of a step/experiment item. 
It can be of type either: - HUMAN: to capture human feedback - CODE: to capture the result of a code execution (deterministic) @@ -127,9 +128,10 @@ def from_dict(cls, score_dict: ScoreDict) -> "Score": @dataclass(repr=False) class Attachment(Utils): """ - An attachment is an object that can be associated with a step. + An attachment is an object that can be associated with a step. It can be an image, a file, a video, etc. """ + step_id: Optional[str] = None thread_id: Optional[str] = None id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4())) @@ -522,6 +524,21 @@ async def __aenter__(self): if active_root_run_var.get() is None and self.step_type == "run": active_root_run_var.set(self.step) + Traceloop.set_association_properties( + { + "literal.thread_id": str(self.step.thread_id), + "literal.parent_id": self.step.id, + "literal.root_run_id": str(self.step.id), + } + ) + else: + Traceloop.set_association_properties( + { + "literal.thread_id": str(self.thread_id), + "literal.parent_id": self.step.id, + "literal.root_run_id": str(self.step.root_run_id), + } + ) return self.step @@ -549,6 +566,21 @@ def __enter__(self) -> Step: if active_root_run_var.get() is None and self.step_type == "run": active_root_run_var.set(self.step) + Traceloop.set_association_properties( + { + "literal.thread_id": str(self.step.thread_id), + "literal.parent_id": self.step.id, + "literal.root_run_id": str(self.step.id), + } + ) + else: + Traceloop.set_association_properties( + { + "literal.thread_id": str(self.thread_id), + "literal.parent_id": self.step.id, + "literal.root_run_id": str(self.step.root_run_id), + } + ) return self.step @@ -637,6 +669,7 @@ def sync_wrapper(*args, **kwargs): step.output = {"content": deepcopy(result)} except Exception: pass + return result return sync_wrapper diff --git a/literalai/observability/thread.py b/literalai/observability/thread.py index 87a22d4..a681f52 100644 --- a/literalai/observability/thread.py +++ b/literalai/observability/thread.py @@ -5,6 +5,8 @@ from functools import wraps from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypedDict +from traceloop.sdk import Traceloop + from literalai.context import active_thread_var from literalai.my_types import UserDict, Utils from literalai.observability.step import Step, StepDict @@ -188,6 +190,11 @@ def __call__(self, func): def __enter__(self) -> "Optional[Thread]": thread_id = self.thread_id if self.thread_id else str(uuid.uuid4()) active_thread_var.set(Thread(id=thread_id, name=self.name, **self.kwargs)) + Traceloop.set_association_properties( + { + "literal.thread_id": thread_id, + } + ) return active_thread_var.get() def __exit__(self, exc_type, exc_val, exc_tb): @@ -198,6 +205,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): async def __aenter__(self): thread_id = self.thread_id if self.thread_id else str(uuid.uuid4()) active_thread_var.set(Thread(id=thread_id, name=self.name, **self.kwargs)) + Traceloop.set_association_properties( + { + "literal.thread_id": thread_id, + } + ) return active_thread_var.get() async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/requirements.txt b/requirements.txt index 836bff6..efe782a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ packaging==23.2 httpx>=0.23.0 pydantic>=1,<3 openai>=1.0.0 -chevron>=0.14.0 \ No newline at end of file +chevron>=0.14.0 +traceloop-sdk>=0.33.9 From 70c6e8a10bd67748b59396b0930dc88f93fcaed3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Mon, 16 Dec 2024 10:06:21 +0100 
Subject: [PATCH 03/10] refactor: allow send llm query --- literalai/client.py | 32 +++- literalai/event_processor.py | 1 + literalai/exporter.py | 288 ++++++++++++++++++++++++++++++++--- 3 files changed, 295 insertions(+), 26 deletions(-) diff --git a/literalai/client.py b/literalai/client.py index 09a295d..928bb9d 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -1,7 +1,10 @@ +import json import os from traceloop.sdk import Traceloop from typing import Any, Dict, List, Optional, Union from typing_extensions import deprecated +import io +from contextlib import redirect_stdout from literalai.api import AsyncLiteralAPI, LiteralAPI from literalai.callback.langchain_callback import get_langchain_callback @@ -125,9 +128,11 @@ def instrument_llamaindex(self): instrument_llamaindex(self.to_sync()) - @classmethod - def initialize(cls): - Traceloop.init(exporter=LoggingSpanExporter()) + def initialize(self): + with redirect_stdout(io.StringIO()): + Traceloop.init( + exporter=LoggingSpanExporter(event_processor=self.event_processor) + ) def langchain_callback( self, @@ -362,6 +367,27 @@ def get_current_root_run(self): """ return active_root_run_var.get() + def set_properties( + self, + name: Optional[str] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + thread = active_thread_var.get() + root_run = active_root_run_var.get() + parent = active_steps_var.get()[-1] if active_steps_var.get() else None + + Traceloop.set_association_properties( + { + "literal.thread_id": str(thread.id) if thread else "None", + "literal.parent_id": str(parent.id) if parent else "None", + "literal.root_run_id": str(root_run.id) if root_run else "None", + "literal.name": str(name) if name else "None", + "literal.tags": json.dumps(tags) if tags else "None", + "literal.metadata": json.dumps(metadata) if metadata else "None", + } + ) + def reset_context(self): """ Resets the context, forgetting active steps & setting current thread to None. 
diff --git a/literalai/event_processor.py b/literalai/event_processor.py index aae1f61..988ab81 100644 --- a/literalai/event_processor.py +++ b/literalai/event_processor.py @@ -98,6 +98,7 @@ def _process_batch(self, batch: List): self.processing_counter -= len(batch) def flush_and_stop(self): + time.sleep(4) self.stop_event.set() if not self.disabled: self.processing_thread.join() diff --git a/literalai/exporter.py b/literalai/exporter.py index f94cee9..34b1f87 100644 --- a/literalai/exporter.py +++ b/literalai/exporter.py @@ -1,19 +1,25 @@ from datetime import datetime, timezone +import json from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult -from typing import Sequence +from typing import Dict, List, Optional, Sequence, cast import logging +from literalai.event_processor import EventProcessor from literalai.helper import utc_now +from literalai.observability.generation import GenerationType from literalai.observability.step import Step -from literalai.context import active_root_run_var, active_steps_var, active_thread_var - class LoggingSpanExporter(SpanExporter): - def __init__(self, logger_name: str = "span_exporter"): + def __init__( + self, + logger_name: str = "span_exporter", + event_processor: Optional[EventProcessor] = None, + ): self.logger = logging.getLogger(logger_name) self.logger.setLevel(logging.INFO) + self.event_processor = event_processor if not self.logger.handlers: handler = logging.StreamHandler() @@ -29,9 +35,11 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: for span in spans: if ( span.attributes - and span.attributes.get("llm.is_streaming", None) is not None + and span.attributes.get("gen_ai.request.model", None) is not None + and self.event_processor is not None ): step = self._create_step_from_span(span) + self.event_processor.add_event(step.to_dict()) self.logger.info(f"Created step from span: {step.to_dict()}") return SpanExportResult.SUCCESS @@ -61,29 +69,263 @@ def _create_step_from_span(self, span: ReadableSpan) -> Step: ) self.logger.info(span.attributes) - parent_id = span.attributes.get("literal.parent_id") - thread_id = span.attributes.get("literal.thread_id") - root_run_id = span.attributes.get("literal.root_run_id") + parent_id = None + thread_id = None + root_run_id = None + metadata = None + name = None + tags = None + generation_messages = None + generation_message_completion = None + generation_prompt = None + generation_completion = None + generation_type = None + generation_model = None + generation_provider = None + generation_max_tokens = None + generation_stream = None + generation_token_count = None + generation_input_token_count = None + generation_output_token_count = None + generation_frequency_penalty = None + generation_logit_bias = None + generation_logprobs = None + generation_top_logprobs = None + generation_n = None + generation_presence_penalty = None + generation_response_format = None + generation_seed = None + generation_stop = None + generation_temperature = None + generation_top_p = None + generation_tool_choice = None - self.logger.info( - f"From span attributes - Parent ID: {parent_id}, Thread ID: {thread_id}, Root Run ID: {root_run_id}" - ) + if span.attributes is not None: + self.logger.info(span.attributes) + parent_id = ( + span.attributes.get( + "traceloop.association.properties.literal.parent_id" + ) + if span.attributes.get( + "traceloop.association.properties.literal.parent_id" + ) + and span.attributes.get( + 
"traceloop.association.properties.literal.parent_id" + ) + != "None" + else None + ) + thread_id = ( + span.attributes.get( + "traceloop.association.properties.literal.thread_id" + ) + if span.attributes.get( + "traceloop.association.properties.literal.thread_id" + ) + and span.attributes.get( + "traceloop.association.properties.literal.thread_id" + ) + != "None" + else None + ) + root_run_id = ( + span.attributes.get( + "traceloop.association.properties.literal.root_run_id" + ) + if span.attributes.get( + "traceloop.association.properties.literal.root_run_id" + ) + and span.attributes.get( + "traceloop.association.properties.literal.root_run_id" + ) + != "None" + else None + ) + metadata = ( + self.extract_json( + str( + span.attributes.get( + "traceloop.association.properties.literal.metadata" + ) + ) + ) + if span.attributes.get( + "traceloop.association.properties.literal.metadata" + ) + and span.attributes.get( + "traceloop.association.properties.literal.metadata" + ) + != "None" + else None + ) + tags = ( + self.extract_json( + str( + span.attributes.get( + "traceloop.association.properties.literal.tags" + ) + ) + ) + if span.attributes.get("traceloop.association.properties.literal.tags") + and span.attributes.get("traceloop.association.properties.literal.tags") + != "None" + else None + ) + name = ( + span.attributes.get("traceloop.association.properties.literal.name") + if span.attributes.get("traceloop.association.properties.literal.name") + and span.attributes.get("traceloop.association.properties.literal.name") + != "None" + else None + ) + generation_type = span.attributes.get("llm.request.type") + generation_messages = ( + self.extract_messages(span.attributes) + if generation_type == "chat" + else None + ) + generation_message_completion = ( + self.extract_messages(span.attributes, "gen_ai.completion.")[0] + if generation_type == "chat" + else None + ) + generation_prompt = span.attributes.get("gen_ai.prompt.0.user") + generation_completion = span.attributes.get("gen_ai.completion.0.content") + generation_model = span.attributes.get("gen_ai.request.model") + generation_provider = span.attributes.get("gen_ai.system") + generation_max_tokens = span.attributes.get("gen_ai.request.max_tokens") + generation_stream = span.attributes.get("llm.is_streaming") + generation_token_count = span.attributes.get("llm.usage.total_tokens") + generation_input_token_count = span.attributes.get( + "gen_ai.usage.prompt_tokens" + ) + generation_output_token_count = span.attributes.get( + "gen_ai.usage.completion_tokens" + ) + # TODO: Validate settings + generation_frequency_penalty = span.attributes.get( + "gen_ai.request.frequency_penalty" + ) + generation_logit_bias = span.attributes.get("gen_ai.request.logit_bias") + generation_logprobs = span.attributes.get("gen_ai.request.logprobs") + generation_top_logprobs = span.attributes.get("gen_ai.request.top_logprobs") + generation_n = span.attributes.get("gen_ai.request.n") + generation_presence_penalty = span.attributes.get( + "gen_ai.request.presence_penalty" + ) + generation_response_format = span.attributes.get( + "gen_ai.request.response_format" + ) + generation_seed = span.attributes.get("gen_ai.request.seed") + generation_stop = span.attributes.get("gen_ai.request.stop") + generation_temperature = span.attributes.get("gen_ai.request.temperature") + generation_top_p = span.attributes.get("gen_ai.request.top_p") + generation_tool_choice = span.attributes.get("gen_ai.request.tool_choice") - step = Step( - id=(str(span.context.span_id) if 
span.context else None), - name=span.name, - type="llm", - start_time=start_time, - end_time=end_time, - thread_id=thread_id, - parent_id=parent_id, - root_run_id=root_run_id, + step = Step.from_dict( + { + "id": (str(span.context.span_id) if span.context else None), + "name": str(name) if name else span.name, + "type": "llm", + "metadata": cast(Dict, metadata), + "startTime": start_time, + "endTime": end_time, + "threadId": str(thread_id) if thread_id else None, + "parentId": str(parent_id) if parent_id else None, + "rootRunId": str(root_run_id) if root_run_id else None, + "input": ( + {"content": generation_messages} + if generation_type == "chat" + else {"content": generation_prompt} + ), + "output": ( + {"content": generation_message_completion} + if generation_type == "chat" + else {"content": generation_completion} + ), + "tags": cast(List, tags), + "generation": { + "prompt": ( + generation_prompt if generation_type == "completion" else None + ), + "completion": ( + generation_completion + if generation_type == "completion" + else None + ), + "type": ( + GenerationType.CHAT + if generation_type == "chat" + else GenerationType.COMPLETION + ), + "model": generation_model, + "provider": generation_provider, + "settings": { + "max_tokens": generation_max_tokens, + "frequency_penalty": generation_frequency_penalty, + "logit_bias": generation_logit_bias, + "logprobs": generation_logprobs, + "top_logprobs": generation_top_logprobs, + "n": generation_n, + "presence_penalty": generation_presence_penalty, + "response_format": generation_response_format, + "seed": generation_seed, + "stop": generation_stop, + "temperature": generation_temperature, + "top_p": generation_top_p, + "tool_choice": generation_tool_choice, + "stream": generation_stream, + }, + "tokenCount": generation_token_count, + "inputTokenCount": generation_input_token_count, + "outputTokenCount": generation_output_token_count, + "messages": generation_messages, + "messageCompletion": generation_message_completion, + }, + } ) - if span.status.is_ok: + if not span.status.is_ok: step.error = span.status.description or "Unknown error" - # Handle input/output/generation based on span attributes - # (We'll implement this in the next iteration) + # TODO: Add generation promptid + # TODO: Add generation variables + # TODO: ttFirstToken + # TODO: duration + # TODO: tokenThroughputInSeconds + # TODO: Add tools + # TODO: error check with gemini error return step + + def extract_messages( + self, data: Dict, prefix: str = "gen_ai.prompt." 
+ ) -> List[Dict]: + messages = [] + index = 0 + + while True: + role_key = f"{prefix}{index}.role" + content_key = f"{prefix}{index}.content" + + if role_key not in data or content_key not in data: + break + + messages.append( + { + "role": data[role_key], + "content": self.extract_json(data[content_key]), + } + ) + + index += 1 + + return messages + + def extract_json(self, data: str) -> Dict | List | str: + try: + content = json.loads(data) + except Exception: + content = data + + return content From 95cdb03aeef95aec078876921e89e99c335a6466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Mon, 16 Dec 2024 17:46:58 +0100 Subject: [PATCH 04/10] refactor: rework method --- literalai/client.py | 3 +- literalai/exporter.py | 328 +++++++++++++----------------------------- 2 files changed, 105 insertions(+), 226 deletions(-) diff --git a/literalai/client.py b/literalai/client.py index 928bb9d..44426ca 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -131,7 +131,8 @@ def instrument_llamaindex(self): def initialize(self): with redirect_stdout(io.StringIO()): Traceloop.init( - exporter=LoggingSpanExporter(event_processor=self.event_processor) + disable_batch=True, + exporter=LoggingSpanExporter(event_processor=self.event_processor), ) def langchain_callback( diff --git a/literalai/exporter.py b/literalai/exporter.py index 34b1f87..2724e7d 100644 --- a/literalai/exporter.py +++ b/literalai/exporter.py @@ -5,10 +5,11 @@ from typing import Dict, List, Optional, Sequence, cast import logging + from literalai.event_processor import EventProcessor from literalai.helper import utc_now from literalai.observability.generation import GenerationType -from literalai.observability.step import Step +from literalai.observability.step import Step, StepDict class LoggingSpanExporter(SpanExporter): @@ -39,8 +40,7 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: and self.event_processor is not None ): step = self._create_step_from_span(span) - self.event_processor.add_event(step.to_dict()) - self.logger.info(f"Created step from span: {step.to_dict()}") + self.event_processor.add_event(cast(StepDict, step.to_dict())) return SpanExportResult.SUCCESS except Exception as e: @@ -49,14 +49,25 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: def shutdown(self): """Shuts down the exporter.""" - pass + if self.event_processor is not None: + return self.event_processor.flush_and_stop() def force_flush(self, timeout_millis: float = 30000) -> bool: """Force flush the exporter.""" return True + # # TODO: Add generation promptid + # # TODO: Add generation variables + # # TODO: Check missing variables + # # TODO: ttFirstToken + # # TODO: duration + # # TODO: tokenThroughputInSeconds + # # TODO: Add tools + # # TODO: error check with gemini error def _create_step_from_span(self, span: ReadableSpan) -> Step: """Convert a span to a Step object""" + attributes = span.attributes or {} + start_time = ( datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc).isoformat() if span.start_time @@ -68,234 +79,101 @@ def _create_step_from_span(self, span: ReadableSpan) -> Step: else utc_now() ) - self.logger.info(span.attributes) - parent_id = None - thread_id = None - root_run_id = None - metadata = None - name = None - tags = None - generation_messages = None - generation_message_completion = None - generation_prompt = None - generation_completion = None - generation_type = None - generation_model = None - generation_provider = None - 
generation_max_tokens = None - generation_stream = None - generation_token_count = None - generation_input_token_count = None - generation_output_token_count = None - generation_frequency_penalty = None - generation_logit_bias = None - generation_logprobs = None - generation_top_logprobs = None - generation_n = None - generation_presence_penalty = None - generation_response_format = None - generation_seed = None - generation_stop = None - generation_temperature = None - generation_top_p = None - generation_tool_choice = None - - if span.attributes is not None: - self.logger.info(span.attributes) - parent_id = ( - span.attributes.get( - "traceloop.association.properties.literal.parent_id" - ) - if span.attributes.get( - "traceloop.association.properties.literal.parent_id" - ) - and span.attributes.get( - "traceloop.association.properties.literal.parent_id" - ) - != "None" + generation_type = attributes.get("llm.request.type") + is_chat = generation_type == "chat" + + span_props = { + "parent_id": attributes.get( + "traceloop.association.properties.literal.parent_id" + ), + "thread_id": attributes.get( + "traceloop.association.properties.literal.thread_id" + ), + "root_run_id": attributes.get( + "traceloop.association.properties.literal.root_run_id" + ), + "metadata": attributes.get( + "traceloop.association.properties.literal.metadata" + ), + "tags": attributes.get("traceloop.association.properties.literal.tags"), + "name": attributes.get("traceloop.association.properties.literal.name"), + } + + span_props = { + k: str(v) for k, v in span_props.items() if v is not None and v != "None" + } + + generation_content = { + "messages": ( + self.extract_messages(cast(Dict, attributes)) if is_chat else None + ), + "message_completion": ( + self.extract_messages(cast(Dict, attributes), "gen_ai.completion.")[0] + if is_chat else None - ) - thread_id = ( - span.attributes.get( - "traceloop.association.properties.literal.thread_id" - ) - if span.attributes.get( - "traceloop.association.properties.literal.thread_id" + ), + "prompt": attributes.get("gen_ai.prompt.0.user"), + "completion": attributes.get("gen_ai.completion.0.content"), + "model": attributes.get("gen_ai.request.model"), + "provider": attributes.get("gen_ai.system"), + } + generation_settings = { + "max_tokens": attributes.get("gen_ai.request.max_tokens"), + "stream": attributes.get("llm.is_streaming"), + "token_count": attributes.get("llm.usage.total_tokens"), + "input_token_count": attributes.get("gen_ai.usage.prompt_tokens"), + "output_token_count": attributes.get("gen_ai.usage.completion_tokens"), + "frequency_penalty": attributes.get("gen_ai.request.frequency_penalty"), + "presence_penalty": attributes.get("gen_ai.request.presence_penalty"), + "temperature": attributes.get("gen_ai.request.temperature"), + "top_p": attributes.get("gen_ai.request.top_p"), + } + + step_dict = { + "id": str(span.context.span_id) if span.context else None, + "name": span_props.get("name", span.name), + "type": "llm", + "metadata": self.extract_json(span_props.get("metadata", "{}")), + "startTime": start_time, + "endTime": end_time, + "threadId": span_props.get("thread_id"), + "parentId": span_props.get("parent_id"), + "rootRunId": span_props.get("root_run_id"), + "tags": self.extract_json(span_props.get("tags", "[]")), + "input": { + "content": ( + generation_content["messages"] + if is_chat + else generation_content["prompt"] ) - and span.attributes.get( - "traceloop.association.properties.literal.thread_id" + }, + "output": { + "content": ( + 
generation_content["message_completion"] + if is_chat + else generation_content["completion"] ) - != "None" - else None - ) - root_run_id = ( - span.attributes.get( - "traceloop.association.properties.literal.root_run_id" - ) - if span.attributes.get( - "traceloop.association.properties.literal.root_run_id" - ) - and span.attributes.get( - "traceloop.association.properties.literal.root_run_id" - ) - != "None" - else None - ) - metadata = ( - self.extract_json( - str( - span.attributes.get( - "traceloop.association.properties.literal.metadata" - ) - ) - ) - if span.attributes.get( - "traceloop.association.properties.literal.metadata" - ) - and span.attributes.get( - "traceloop.association.properties.literal.metadata" - ) - != "None" - else None - ) - tags = ( - self.extract_json( - str( - span.attributes.get( - "traceloop.association.properties.literal.tags" - ) - ) - ) - if span.attributes.get("traceloop.association.properties.literal.tags") - and span.attributes.get("traceloop.association.properties.literal.tags") - != "None" - else None - ) - name = ( - span.attributes.get("traceloop.association.properties.literal.name") - if span.attributes.get("traceloop.association.properties.literal.name") - and span.attributes.get("traceloop.association.properties.literal.name") - != "None" - else None - ) - generation_type = span.attributes.get("llm.request.type") - generation_messages = ( - self.extract_messages(span.attributes) - if generation_type == "chat" - else None - ) - generation_message_completion = ( - self.extract_messages(span.attributes, "gen_ai.completion.")[0] - if generation_type == "chat" - else None - ) - generation_prompt = span.attributes.get("gen_ai.prompt.0.user") - generation_completion = span.attributes.get("gen_ai.completion.0.content") - generation_model = span.attributes.get("gen_ai.request.model") - generation_provider = span.attributes.get("gen_ai.system") - generation_max_tokens = span.attributes.get("gen_ai.request.max_tokens") - generation_stream = span.attributes.get("llm.is_streaming") - generation_token_count = span.attributes.get("llm.usage.total_tokens") - generation_input_token_count = span.attributes.get( - "gen_ai.usage.prompt_tokens" - ) - generation_output_token_count = span.attributes.get( - "gen_ai.usage.completion_tokens" - ) - # TODO: Validate settings - generation_frequency_penalty = span.attributes.get( - "gen_ai.request.frequency_penalty" - ) - generation_logit_bias = span.attributes.get("gen_ai.request.logit_bias") - generation_logprobs = span.attributes.get("gen_ai.request.logprobs") - generation_top_logprobs = span.attributes.get("gen_ai.request.top_logprobs") - generation_n = span.attributes.get("gen_ai.request.n") - generation_presence_penalty = span.attributes.get( - "gen_ai.request.presence_penalty" - ) - generation_response_format = span.attributes.get( - "gen_ai.request.response_format" - ) - generation_seed = span.attributes.get("gen_ai.request.seed") - generation_stop = span.attributes.get("gen_ai.request.stop") - generation_temperature = span.attributes.get("gen_ai.request.temperature") - generation_top_p = span.attributes.get("gen_ai.request.top_p") - generation_tool_choice = span.attributes.get("gen_ai.request.tool_choice") - - step = Step.from_dict( - { - "id": (str(span.context.span_id) if span.context else None), - "name": str(name) if name else span.name, - "type": "llm", - "metadata": cast(Dict, metadata), - "startTime": start_time, - "endTime": end_time, - "threadId": str(thread_id) if thread_id else None, - "parentId": 
str(parent_id) if parent_id else None, - "rootRunId": str(root_run_id) if root_run_id else None, - "input": ( - {"content": generation_messages} - if generation_type == "chat" - else {"content": generation_prompt} - ), - "output": ( - {"content": generation_message_completion} - if generation_type == "chat" - else {"content": generation_completion} - ), - "tags": cast(List, tags), - "generation": { - "prompt": ( - generation_prompt if generation_type == "completion" else None - ), - "completion": ( - generation_completion - if generation_type == "completion" - else None - ), - "type": ( - GenerationType.CHAT - if generation_type == "chat" - else GenerationType.COMPLETION - ), - "model": generation_model, - "provider": generation_provider, - "settings": { - "max_tokens": generation_max_tokens, - "frequency_penalty": generation_frequency_penalty, - "logit_bias": generation_logit_bias, - "logprobs": generation_logprobs, - "top_logprobs": generation_top_logprobs, - "n": generation_n, - "presence_penalty": generation_presence_penalty, - "response_format": generation_response_format, - "seed": generation_seed, - "stop": generation_stop, - "temperature": generation_temperature, - "top_p": generation_top_p, - "tool_choice": generation_tool_choice, - "stream": generation_stream, - }, - "tokenCount": generation_token_count, - "inputTokenCount": generation_input_token_count, - "outputTokenCount": generation_output_token_count, - "messages": generation_messages, - "messageCompletion": generation_message_completion, - }, - } - ) + }, + "generation": { + "type": GenerationType.CHAT if is_chat else GenerationType.COMPLETION, + "prompt": generation_content["prompt"] if not is_chat else None, + "completion": generation_content["completion"] if not is_chat else None, + "model": generation_content["model"], + "provider": generation_content["provider"], + "settings": generation_settings, + "tokenCount": generation_settings["token_count"], + "inputTokenCount": generation_settings["input_token_count"], + "outputTokenCount": generation_settings["output_token_count"], + "messages": generation_content["messages"], + "messageCompletion": generation_content["message_completion"], + }, + } + + step = Step.from_dict(cast(StepDict, step_dict)) if not span.status.is_ok: step.error = span.status.description or "Unknown error" - # TODO: Add generation promptid - # TODO: Add generation variables - # TODO: ttFirstToken - # TODO: duration - # TODO: tokenThroughputInSeconds - # TODO: Add tools - # TODO: error check with gemini error - return step def extract_messages( From b63e9123d2e12648008f653a05b61233d2bf6862 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Tue, 17 Dec 2024 11:04:24 +0100 Subject: [PATCH 05/10] refactor: address missing issues --- examples/langchain_toolcall.py | 9 ++-- examples/langchain_variable.py | 8 ++-- examples/multimodal.py | 84 ++++++++++++++++++++++++++++++++++ examples/streaming.py | 2 +- literalai/client.py | 3 ++ literalai/event_processor.py | 1 - literalai/exporter.py | 69 +++++++++++++++++++++------- 7 files changed, 151 insertions(+), 25 deletions(-) create mode 100644 examples/multimodal.py diff --git a/examples/langchain_toolcall.py b/examples/langchain_toolcall.py index 6fc3939..f9fc60f 100644 --- a/examples/langchain_toolcall.py +++ b/examples/langchain_toolcall.py @@ -19,6 +19,7 @@ tools = [search] lai_client = LiteralClient() +lai_client.initialize() lai_prompt = lai_client.api.get_or_create_prompt( name="LC Agent", settings={ @@ -37,13 +38,13 @@ {"role": 
"assistant", "content": "{{agent_scratchpad}}"}, ], ) -prompt = lai_prompt.to_langchain_chat_prompt_template() +prompt = lai_prompt.to_langchain_chat_prompt_template( + additional_messages=[("placeholder", "{agent_scratchpad}")], +) agent: BaseSingleActionAgent = create_tool_calling_agent(model, tools, prompt) # type: ignore agent_executor = AgentExecutor(agent=agent, tools=tools) -cb = lai_client.langchain_callback() - # Replace with ainvoke for asynchronous execution. agent_executor.invoke( { @@ -56,5 +57,5 @@ ], "input": "whats the weather in sf?", }, - config=RunnableConfig(callbacks=[cb], run_name="Weather SF"), + config=RunnableConfig(run_name="Weather SF"), ) diff --git a/examples/langchain_variable.py b/examples/langchain_variable.py index c58af52..d0da94c 100644 --- a/examples/langchain_variable.py +++ b/examples/langchain_variable.py @@ -1,12 +1,13 @@ from langchain.chat_models import init_chat_model from literalai import LiteralClient -from langchain.schema.runnable.config import RunnableConfig + from dotenv import load_dotenv load_dotenv() lai = LiteralClient() +lai.initialize() prompt = lai.api.get_or_create_prompt( name="user intent", @@ -29,13 +30,14 @@ input_messages = messages.format_messages( user_message="The screen is cracked, there are scratches on the surface, and a component is missing." ) -cb = lai.langchain_callback() # Returns a langchain_openai.ChatOpenAI instance. gpt_4o = init_chat_model( # type: ignore model_provider=prompt.provider, **prompt.settings, ) -print(gpt_4o.invoke(input_messages, config=RunnableConfig(callbacks=[cb]))) + +lai.set_properties(prompt=prompt) +print(gpt_4o.invoke(input_messages)) lai.flush_and_stop() diff --git a/examples/multimodal.py b/examples/multimodal.py new file mode 100644 index 0000000..2f4a3b5 --- /dev/null +++ b/examples/multimodal.py @@ -0,0 +1,84 @@ +import base64 +import requests # type: ignore +import time + +from literalai import LiteralClient +from openai import OpenAI + +from dotenv import load_dotenv + +from literalai.observability.step import ScoreDict + +load_dotenv() + +openai_client = OpenAI() + +literalai_client = LiteralClient() +literalai_client.initialize() + + +def encode_image(url): + return base64.b64encode(requests.get(url).content) + + +@literalai_client.step(type="run") +def generate_answer(user_query, image_url): + literalai_client.set_properties( + name="foobar", + metadata={"foo": "bar"}, + tags=["foo", "bar"], + ) + completion = openai_client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": user_query}, + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + ], + }, + ], + max_tokens=300, + ) + return completion.choices[0].message.content + + +def main(): + with literalai_client.thread(name="Meal Analyzer") as thread: + welcome_message = ( + "Welcome to the meal analyzer, please upload an image of your plate!" + ) + literalai_client.message( + content=welcome_message, type="assistant_message", name="My Assistant" + ) + + user_query = "Is this a healthy meal?" 
+ user_image = "https://www.eatthis.com/wp-content/uploads/sites/4/2021/05/healthy-plate.jpg" + user_step = literalai_client.message( + content=user_query, type="user_message", name="User" + ) + + time.sleep(1) # to make sure the user step has arrived at Literal AI + + literalai_client.api.create_attachment( + thread_id=thread.id, + step_id=user_step.id, + name="meal_image", + content=encode_image(user_image), + ) + + answer = generate_answer(user_query=user_query, image_url=user_image) + literalai_client.message( + content=answer, type="assistant_message", name="My Assistant" + ) + + +main() +# Network requests by the SDK are performed asynchronously. +# Invoke flush_and_stop() to guarantee the completion of all requests prior to the process termination. +# WARNING: If you run a continuous server, you should not use this method. +literalai_client.flush_and_stop() diff --git a/examples/streaming.py b/examples/streaming.py index ed47a3e..1c884a1 100644 --- a/examples/streaming.py +++ b/examples/streaming.py @@ -12,7 +12,7 @@ sdk = LiteralClient(batch_size=2) -sdk.instrument_openai() +sdk.initialize() @sdk.thread diff --git a/literalai/client.py b/literalai/client.py index 44426ca..bda138c 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -29,6 +29,7 @@ step_decorator, ) from literalai.observability.thread import ThreadContextManager, thread_decorator +from literalai.prompt_engineering.prompt import Prompt from literalai.requirements import check_all_requirements @@ -373,6 +374,7 @@ def set_properties( name: Optional[str] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + prompt: Optional[Prompt] = None, ): thread = active_thread_var.get() root_run = active_root_run_var.get() @@ -386,6 +388,7 @@ def set_properties( "literal.name": str(name) if name else "None", "literal.tags": json.dumps(tags) if tags else "None", "literal.metadata": json.dumps(metadata) if metadata else "None", + "literal.prompt": json.dumps(prompt.to_dict()) if prompt else "None", } ) diff --git a/literalai/event_processor.py b/literalai/event_processor.py index 988ab81..aae1f61 100644 --- a/literalai/event_processor.py +++ b/literalai/event_processor.py @@ -98,7 +98,6 @@ def _process_batch(self, batch: List): self.processing_counter -= len(batch) def flush_and_stop(self): - time.sleep(4) self.stop_event.set() if not self.disabled: self.processing_thread.join() diff --git a/literalai/exporter.py b/literalai/exporter.py index 2724e7d..a7b7c32 100644 --- a/literalai/exporter.py +++ b/literalai/exporter.py @@ -1,5 +1,6 @@ -from datetime import datetime, timezone +from datetime import date, datetime, timezone import json +from annotated_types import Timezone from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult from typing import Dict, List, Optional, Sequence, cast @@ -10,6 +11,7 @@ from literalai.helper import utc_now from literalai.observability.generation import GenerationType from literalai.observability.step import Step, StepDict +from literalai.prompt_engineering.prompt import PromptDict class LoggingSpanExporter(SpanExporter): @@ -56,14 +58,8 @@ def force_flush(self, timeout_millis: float = 30000) -> bool: """Force flush the exporter.""" return True - # # TODO: Add generation promptid - # # TODO: Add generation variables - # # TODO: Check missing variables - # # TODO: ttFirstToken - # # TODO: duration - # # TODO: tokenThroughputInSeconds - # # TODO: Add tools - # # TODO: error check with gemini error + # 
TODO: error check with gemini error + # TODO: ttFirstToken def _create_step_from_span(self, span: ReadableSpan) -> Step: """Convert a span to a Step object""" attributes = span.attributes or {} @@ -78,6 +74,11 @@ def _create_step_from_span(self, span: ReadableSpan) -> Step: if span.end_time else utc_now() ) + duration, token_throughput = self._calculate_duration_and_throughput( + span.start_time, + span.end_time, + int(str(attributes.get("llm.usage.total_tokens", 0))), + ) generation_type = attributes.get("llm.request.type") is_chat = generation_type == "chat" @@ -103,19 +104,35 @@ def _create_step_from_span(self, span: ReadableSpan) -> Step: k: str(v) for k, v in span_props.items() if v is not None and v != "None" } + serialized_prompt = attributes.get( + "traceloop.association.properties.literal.prompt" + ) + prompt = cast( + Optional[PromptDict], + ( + self._extract_json(str(serialized_prompt)) + if serialized_prompt and serialized_prompt != "None" + else None + ), + ) + generation_content = { + "duration": duration, "messages": ( - self.extract_messages(cast(Dict, attributes)) if is_chat else None + self._extract_messages(cast(Dict, attributes)) if is_chat else None ), "message_completion": ( - self.extract_messages(cast(Dict, attributes), "gen_ai.completion.")[0] + self._extract_messages(cast(Dict, attributes), "gen_ai.completion.")[0] if is_chat else None ), "prompt": attributes.get("gen_ai.prompt.0.user"), + "promptId": prompt.get("id") if prompt else None, "completion": attributes.get("gen_ai.completion.0.content"), "model": attributes.get("gen_ai.request.model"), "provider": attributes.get("gen_ai.system"), + "tokenThroughputInSeconds": token_throughput, + "variables": prompt.get("variables") if prompt else None, } generation_settings = { "max_tokens": attributes.get("gen_ai.request.max_tokens"), @@ -133,13 +150,13 @@ def _create_step_from_span(self, span: ReadableSpan) -> Step: "id": str(span.context.span_id) if span.context else None, "name": span_props.get("name", span.name), "type": "llm", - "metadata": self.extract_json(span_props.get("metadata", "{}")), + "metadata": self._extract_json(span_props.get("metadata", "{}")), "startTime": start_time, "endTime": end_time, "threadId": span_props.get("thread_id"), "parentId": span_props.get("parent_id"), "rootRunId": span_props.get("root_run_id"), - "tags": self.extract_json(span_props.get("tags", "[]")), + "tags": self._extract_json(span_props.get("tags", "[]")), "input": { "content": ( generation_content["messages"] @@ -176,7 +193,7 @@ def _create_step_from_span(self, span: ReadableSpan) -> Step: return step - def extract_messages( + def _extract_messages( self, data: Dict, prefix: str = "gen_ai.prompt." 
) -> List[Dict]: messages = [] @@ -188,11 +205,13 @@ def extract_messages( if role_key not in data or content_key not in data: break + if data[role_key] == "placeholder": + break messages.append( { "role": data[role_key], - "content": self.extract_json(data[content_key]), + "content": self._extract_json(data[content_key]), } ) @@ -200,10 +219,28 @@ def extract_messages( return messages - def extract_json(self, data: str) -> Dict | List | str: + def _extract_json(self, data: str) -> Dict | List | str: try: content = json.loads(data) except Exception: content = data return content + + def _calculate_duration_and_throughput( + self, + start_time_ns: Optional[int], + end_time_ns: Optional[int], + total_tokens: Optional[int], + ) -> tuple[float, Optional[float]]: + """Calculate duration in seconds and token throughput per second.""" + duration_ns = ( + end_time_ns - start_time_ns if start_time_ns and end_time_ns else 0 + ) + duration_seconds = duration_ns / 1e9 + + token_throughput = None + if total_tokens is not None and duration_seconds > 0: + token_throughput = total_tokens / duration_seconds + + return duration_seconds, token_throughput From e09d2612391667e62cdd69dff1a3fadeeefae3af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Tue, 17 Dec 2024 13:52:18 +0100 Subject: [PATCH 06/10] fix: fix requirements --- examples/multimodal.py | 1 - literalai/exporter.py | 3 +-- requirements.txt | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/multimodal.py b/examples/multimodal.py index 2f4a3b5..f8e4a54 100644 --- a/examples/multimodal.py +++ b/examples/multimodal.py @@ -7,7 +7,6 @@ from dotenv import load_dotenv -from literalai.observability.step import ScoreDict load_dotenv() diff --git a/literalai/exporter.py b/literalai/exporter.py index a7b7c32..ac43549 100644 --- a/literalai/exporter.py +++ b/literalai/exporter.py @@ -1,6 +1,5 @@ -from datetime import date, datetime, timezone +from datetime import datetime, timezone import json -from annotated_types import Timezone from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult from typing import Dict, List, Optional, Sequence, cast diff --git a/requirements.txt b/requirements.txt index efe782a..f1d4670 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ httpx>=0.23.0 pydantic>=1,<3 openai>=1.0.0 chevron>=0.14.0 -traceloop-sdk>=0.33.9 +traceloop-sdk>=0.33.12 From 1cdace9ab2641520a23bd6917f65b0ad2c0c6342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Tue, 17 Dec 2024 15:35:28 +0100 Subject: [PATCH 07/10] fix: fix types --- literalai/exporter.py | 8 ++++---- mypy.ini | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/literalai/exporter.py b/literalai/exporter.py index ac43549..979730d 100644 --- a/literalai/exporter.py +++ b/literalai/exporter.py @@ -2,7 +2,7 @@ import json from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult -from typing import Dict, List, Optional, Sequence, cast +from typing import Dict, List, Optional, Sequence, Union, cast import logging @@ -149,13 +149,13 @@ def _create_step_from_span(self, span: ReadableSpan) -> Step: "id": str(span.context.span_id) if span.context else None, "name": span_props.get("name", span.name), "type": "llm", - "metadata": self._extract_json(span_props.get("metadata", "{}")), + "metadata": self._extract_json(str(span_props.get("metadata", "{}"))), "startTime": 
start_time, "endTime": end_time, "threadId": span_props.get("thread_id"), "parentId": span_props.get("parent_id"), "rootRunId": span_props.get("root_run_id"), - "tags": self._extract_json(span_props.get("tags", "[]")), + "tags": self._extract_json(str(span_props.get("tags", "[]"))), "input": { "content": ( generation_content["messages"] @@ -218,7 +218,7 @@ def _extract_messages( return messages - def _extract_json(self, data: str) -> Dict | List | str: + def _extract_json(self, data: str) -> Union[Dict, List, str]: try: content = json.loads(data) except Exception: diff --git a/mypy.ini b/mypy.ini index d3aaeb3..f54d425 100644 --- a/mypy.ini +++ b/mypy.ini @@ -8,4 +8,7 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-langchain_community.*] -ignore_missing_imports = True \ No newline at end of file +ignore_missing_imports = True + +[mypy-traceloop.*] +ignore_missing_imports = True From ff7f64a1224085ca53dd662f93db37d01b0e0b8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Tue, 17 Dec 2024 17:44:55 +0100 Subject: [PATCH 08/10] feat: added workflow test --- examples/llamaindex_workflow.py | 48 +++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 examples/llamaindex_workflow.py diff --git a/examples/llamaindex_workflow.py b/examples/llamaindex_workflow.py new file mode 100644 index 0000000..8e06e82 --- /dev/null +++ b/examples/llamaindex_workflow.py @@ -0,0 +1,48 @@ +import asyncio +from llama_index.core.workflow import ( + Event, + StartEvent, + StopEvent, + Workflow, + step, +) +from llama_index.llms.openai import OpenAI +from literalai.client import LiteralClient + +lai_client = LiteralClient() +lai_client.initialize() + + +class JokeEvent(Event): + joke: str + + +class JokeFlow(Workflow): + llm = OpenAI() + + @step() + async def generate_joke(self, ev: StartEvent) -> JokeEvent: + topic = ev.topic + + prompt = f"Write your best joke about {topic}." 
+ response = await self.llm.acomplete(prompt) + return JokeEvent(joke=str(response)) + + @step() + async def critique_joke(self, ev: JokeEvent) -> StopEvent: + joke = ev.joke + + prompt = f"Give a thorough analysis and critique of the following joke: {joke}" + response = await self.llm.acomplete(prompt) + return StopEvent(result=str(response)) + + +@lai_client.thread(name="JokeFlow") +async def main(): + w = JokeFlow(timeout=60, verbose=False) + result = await w.run(topic="pirates") + print(str(result)) + + +if __name__ == "__main__": + asyncio.run(main()) From 8178a2762a4c66b62d81b9deeb9c0d216274f7ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Tue, 17 Dec 2024 17:47:23 +0100 Subject: [PATCH 09/10] fix: fix types --- mypy.ini | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mypy.ini b/mypy.ini index f54d425..7e4bc03 100644 --- a/mypy.ini +++ b/mypy.ini @@ -12,3 +12,7 @@ ignore_missing_imports = True [mypy-traceloop.*] ignore_missing_imports = True + +[mypy-llama_index.*] +ignore_missing_imports = True + From f381baff89d2c8a5b82779041d3ac6b3f5779c2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Fri, 20 Dec 2024 15:21:51 +0100 Subject: [PATCH 10/10] refactor: update todos --- examples/llamaindex_workflow.py | 9 ++++++++- literalai/exporter.py | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/llamaindex_workflow.py b/examples/llamaindex_workflow.py index 8e06e82..c580e40 100644 --- a/examples/llamaindex_workflow.py +++ b/examples/llamaindex_workflow.py @@ -16,6 +16,9 @@ class JokeEvent(Event): joke: str +class RewriteJoke(Event): + joke: str + class JokeFlow(Workflow): llm = OpenAI() @@ -29,7 +32,11 @@ async def generate_joke(self, ev: StartEvent) -> JokeEvent: return JokeEvent(joke=str(response)) @step() - async def critique_joke(self, ev: JokeEvent) -> StopEvent: + async def return_joke(self, ev: JokeEvent) -> RewriteJoke: + return RewriteJoke(joke=ev.joke + "What is funny?") + + @step() + async def critique_joke(self, ev: RewriteJoke) -> StopEvent: joke = ev.joke prompt = f"Give a thorough analysis and critique of the following joke: {joke}" diff --git a/literalai/exporter.py b/literalai/exporter.py index 979730d..28877e3 100644 --- a/literalai/exporter.py +++ b/literalai/exporter.py @@ -13,6 +13,8 @@ from literalai.prompt_engineering.prompt import PromptDict +# TODO: Suppport Gemini models https://github.com/traceloop/openllmetry/issues/2419 +# TODO: Support llamaindex workflow https://github.com/traceloop/openllmetry/pull/2421 class LoggingSpanExporter(SpanExporter): def __init__( self, @@ -57,8 +59,6 @@ def force_flush(self, timeout_millis: float = 30000) -> bool: """Force flush the exporter.""" return True - # TODO: error check with gemini error - # TODO: ttFirstToken def _create_step_from_span(self, span: ReadableSpan) -> Step: """Convert a span to a Step object""" attributes = span.attributes or {}
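Taken together, this series deprecates the per-SDK instrument_openai/instrument_mistralai/instrument_llamaindex helpers in favor of a single OpenTelemetry/Traceloop pipeline: LiteralClient.initialize() installs the new LoggingSpanExporter, which converts LLM spans into Literal AI steps, while set_properties() and the thread/step context managers propagate thread, run, and prompt context through Traceloop association properties. Below is a minimal end-to-end sketch of the resulting usage, pieced together from the examples added in these patches (examples/streaming.py and examples/multimodal.py); the model name, prompt text, and step/thread names are illustrative only, not part of the series.

    from literalai import LiteralClient
    from openai import OpenAI

    literalai_client = LiteralClient()
    # Replaces the deprecated instrument_openai()/instrument_mistralai()/instrument_llamaindex() calls.
    literalai_client.initialize()

    openai_client = OpenAI()

    @literalai_client.step(type="run")
    def answer(question: str) -> str:
        # Optionally attach a name, tags, and metadata to the spans emitted under this run.
        literalai_client.set_properties(
            name="answer-run",
            tags=["example"],
            metadata={"source": "sketch"},
        )
        completion = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": question}],
        )
        return completion.choices[0].message.content

    with literalai_client.thread(name="Demo thread"):
        print(answer("What does LoggingSpanExporter do?"))

    # Flush pending steps before the process exits.
    literalai_client.flush_and_stop()

As noted in examples/multimodal.py, flush_and_stop() is only intended for short-lived scripts; a long-running server should not call it and should instead rely on the exporter's background processing.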