From 89e49085c9efe9fcab13b1b9f1d8720c380fcc08 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Tue, 18 Feb 2025 11:58:45 -0800 Subject: [PATCH 1/2] fix: thread/step concurrency --- literalai/__init__.py | 7 +- literalai/api/asynchronous.py | 79 ++++++++++--------- literalai/api/base.py | 56 ++++--------- literalai/api/helpers/generation_helpers.py | 11 +-- literalai/api/synchronous.py | 79 ++++++++++--------- literalai/cache/shared_cache.py | 2 +- literalai/client.py | 13 +-- literalai/context.py | 2 +- literalai/evaluation/dataset.py | 1 + literalai/evaluation/dataset_experiment.py | 2 + literalai/evaluation/dataset_item.py | 1 + literalai/exporter.py | 18 +++-- .../instrumentation/llamaindex/__init__.py | 7 +- literalai/instrumentation/openai.py | 4 +- literalai/my_types.py | 9 ++- literalai/observability/message.py | 6 +- literalai/observability/step.py | 68 ++++++++++------ literalai/version.py | 2 +- literalai/wrappers.py | 7 +- setup.py | 2 +- 20 files changed, 200 insertions(+), 176 deletions(-) diff --git a/literalai/__init__.py b/literalai/__init__.py index bb17064f..cd17f856 100644 --- a/literalai/__init__.py +++ b/literalai/__init__.py @@ -1,8 +1,10 @@ from literalai.client import AsyncLiteralClient, LiteralClient from literalai.evaluation.dataset import Dataset +from literalai.evaluation.dataset_experiment import ( + DatasetExperiment, + DatasetExperimentItem, +) from literalai.evaluation.dataset_item import DatasetItem -from literalai.evaluation.dataset_experiment import DatasetExperiment, DatasetExperimentItem -from literalai.prompt_engineering.prompt import Prompt from literalai.my_types import * # noqa from literalai.observability.generation import ( BaseGeneration, @@ -13,6 +15,7 @@ from literalai.observability.message import Message from literalai.observability.step import Attachment, Score, Step from literalai.observability.thread import Thread +from literalai.prompt_engineering.prompt import Prompt from literalai.version import __version__ __all__ = [ diff --git a/literalai/api/asynchronous.py b/literalai/api/asynchronous.py index ddb85382..02ddd5f1 100644 --- a/literalai/api/asynchronous.py +++ b/literalai/api/asynchronous.py @@ -1,21 +1,11 @@ import logging import uuid +from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, cast +import httpx from typing_extensions import deprecated -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Optional, - TypeVar, - Union, - cast, -) from literalai.api.base import BaseLiteralAPI, prepare_variables - from literalai.api.helpers.attachment_helpers import ( AttachmentUpload, create_attachment_helper, @@ -91,6 +81,7 @@ DatasetExperimentItem, ) from literalai.evaluation.dataset_item import DatasetItem +from literalai.my_types import PaginatedResponse, User from literalai.observability.filter import ( generations_filters, generations_order_by, @@ -102,12 +93,6 @@ threads_order_by, users_filters, ) -from literalai.observability.thread import Thread -from literalai.prompt_engineering.prompt import Prompt, ProviderSettings - -import httpx - -from literalai.my_types import PaginatedResponse, User from literalai.observability.generation import ( BaseGeneration, ChatGeneration, @@ -123,6 +108,8 @@ StepDict, StepType, ) +from literalai.observability.thread import Thread +from literalai.prompt_engineering.prompt import Prompt, ProviderSettings logger = logging.getLogger(__name__) @@ -141,7 +128,11 @@ class AsyncLiteralAPI(BaseLiteralAPI): R = TypeVar("R") async def make_gql_call( - self, 
description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10 + self, + description: str, + query: str, + variables: Dict[str, Any], + timeout: Optional[int] = 10, ) -> Dict: def raise_error(error): logger.error(f"Failed to {description}: {error}") @@ -166,8 +157,7 @@ def raise_error(error): json = response.json() except ValueError as e: raise_error( - f"""Failed to parse JSON response: { - e}, content: {response.content!r}""" + f"Failed to parse JSON response: {e}, content: {response.content!r}" ) if json.get("errors"): @@ -178,8 +168,7 @@ def raise_error(error): for value in json["data"].values(): if value and value.get("ok") is False: raise_error( - f"""Failed to {description}: { - value.get('message')}""" + f"""Failed to {description}: {value.get("message")}""" ) return json @@ -203,9 +192,9 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict: return response.json() except ValueError as e: raise ValueError( - f"""Failed to parse JSON response: { - e}, content: {response.content!r}""" + f"Failed to parse JSON response: {e}, content: {response.content!r}" ) + async def gql_helper( self, query: str, @@ -235,7 +224,9 @@ async def get_user( ) -> "User": return await self.gql_helper(*get_user_helper(id, identifier)) - async def create_user(self, identifier: str, metadata: Optional[Dict] = None) -> "User": + async def create_user( + self, identifier: str, metadata: Optional[Dict] = None + ) -> "User": return await self.gql_helper(*create_user_helper(identifier, metadata)) async def update_user( @@ -245,7 +236,7 @@ async def update_user( async def delete_user(self, id: str) -> Dict: return await self.gql_helper(*delete_user_helper(id)) - + async def get_or_create_user( self, identifier: str, metadata: Optional[Dict] = None ) -> "User": @@ -273,7 +264,7 @@ async def get_threads( first, after, before, filters, order_by, step_types_to_keep ) ) - + async def list_threads( self, first: Optional[int] = None, @@ -491,7 +482,7 @@ async def create_attachment( thread_id = active_thread.id if not step_id: - if active_steps := active_steps_var.get([]): + if active_steps := active_steps_var.get(): step_id = active_steps[-1].id else: raise Exception("No step_id provided and no active step found.") @@ -532,7 +523,9 @@ async def create_attachment( response = await self.make_gql_call(description, query, variables) return process_response(response) - async def update_attachment(self, id: str, update_params: AttachmentUpload) -> "Attachment": + async def update_attachment( + self, id: str, update_params: AttachmentUpload + ) -> "Attachment": return await self.gql_helper(*update_attachment_helper(id, update_params)) async def get_attachment(self, id: str) -> Optional["Attachment"]: @@ -545,7 +538,6 @@ async def delete_attachment(self, id: str) -> Dict: # Step APIs # ################################################################################## - async def create_step( self, thread_id: Optional[str] = None, @@ -646,7 +638,7 @@ async def get_generations( return await self.gql_helper( *get_generations_helper(first, after, before, filters, order_by) ) - + async def create_generation( self, generation: Union["ChatGeneration", "CompletionGeneration"] ) -> Union["ChatGeneration", "CompletionGeneration"]: @@ -667,8 +659,10 @@ async def create_dataset( return await self.gql_helper( *create_dataset_helper(sync_api, name, description, metadata, type) ) - - async def get_dataset(self, id: Optional[str] = None, name: Optional[str] = None) -> "Dataset": + + async def 
get_dataset( + self, id: Optional[str] = None, name: Optional[str] = None + ) -> "Dataset": sync_api = LiteralAPI(self.api_key, self.url) subpath, _, variables, process_response = get_dataset_helper( sync_api, id=id, name=name @@ -738,7 +732,7 @@ async def create_experiment_item( result.scores = await self.create_scores(experiment_item.scores) return result - + ################################################################################## # DatasetItem APIs # ################################################################################## @@ -753,7 +747,7 @@ async def create_dataset_item( return await self.gql_helper( *create_dataset_item_helper(dataset_id, input, expected_output, metadata) ) - + async def get_dataset_item(self, id: str) -> "DatasetItem": return await self.gql_helper(*get_dataset_item_helper(id)) @@ -784,7 +778,9 @@ async def get_or_create_prompt_lineage( return await self.gql_helper(*create_prompt_lineage_helper(name, description)) @deprecated('Please use "get_or_create_prompt_lineage" instead.') - async def create_prompt_lineage(self, name: str, description: Optional[str] = None) -> Dict: + async def create_prompt_lineage( + self, name: str, description: Optional[str] = None + ) -> Dict: return await self.get_or_create_prompt_lineage(name, description) async def get_or_create_prompt( @@ -838,7 +834,14 @@ async def get_prompt( raise ValueError("At least the `id` or the `name` must be provided.") sync_api = LiteralAPI(self.api_key, self.url) - get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper( + ( + get_prompt_query, + description, + variables, + process_response, + timeout, + cached_prompt, + ) = get_prompt_helper( api=sync_api, id=id, name=name, version=version, cache=self.cache ) diff --git a/literalai/api/base.py b/literalai/api/base.py index 2b5f8873..da347003 100644 --- a/literalai/api/base.py +++ b/literalai/api/base.py @@ -1,29 +1,16 @@ import os - from abc import ABC, abstractmethod -from typing import ( - Any, - Dict, - List, - Optional, - Union, -) +from typing import Any, Dict, List, Optional, Union from typing_extensions import deprecated -from literalai.my_types import Environment - +from literalai.api.helpers.attachment_helpers import AttachmentUpload +from literalai.api.helpers.prompt_helpers import PromptRollout +from literalai.api.helpers.score_helpers import ScoreUpdate from literalai.cache.shared_cache import SharedCache from literalai.evaluation.dataset import DatasetType -from literalai.evaluation.dataset_experiment import ( - DatasetExperimentItem, -) -from literalai.api.helpers.attachment_helpers import ( - AttachmentUpload) -from literalai.api.helpers.score_helpers import ( - ScoreUpdate, -) - +from literalai.evaluation.dataset_experiment import DatasetExperimentItem +from literalai.my_types import Environment from literalai.observability.filter import ( generations_filters, generations_order_by, @@ -35,24 +22,14 @@ threads_order_by, users_filters, ) -from literalai.prompt_engineering.prompt import ProviderSettings - - -from literalai.api.helpers.prompt_helpers import ( - PromptRollout) - from literalai.observability.generation import ( ChatGeneration, CompletionGeneration, GenerationMessage, ) -from literalai.observability.step import ( - ScoreDict, - ScoreType, - Step, - StepDict, - StepType, -) +from literalai.observability.step import ScoreDict, ScoreType, Step, StepDict, StepType +from literalai.prompt_engineering.prompt import ProviderSettings + def prepare_variables(variables: 
Dict[str, Any]) -> Dict[str, Any]: """ @@ -72,6 +49,7 @@ def handle_bytes(item): return handle_bytes(variables) + class BaseLiteralAPI(ABC): def __init__( self, @@ -676,7 +654,7 @@ def delete_step( @abstractmethod def send_steps(self, steps: List[Union[StepDict, "Step"]]): """ - Sends a list of steps to process. + Sends a list of steps to process. Step ingestion happens asynchronously if you configured a cache. See [Cache Configuration](https://docs.literalai.com/self-hosting/deployment#4-cache-configuration-optional). Args: @@ -773,9 +751,7 @@ def create_dataset( pass @abstractmethod - def get_dataset( - self, id: Optional[str] = None, name: Optional[str] = None - ): + def get_dataset(self, id: Optional[str] = None, name: Optional[str] = None): """ Retrieves a dataset by its ID or name. @@ -846,9 +822,7 @@ def create_experiment( pass @abstractmethod - def create_experiment_item( - self, experiment_item: DatasetExperimentItem - ): + def create_experiment_item(self, experiment_item: DatasetExperimentItem): """ Creates an experiment item within an existing experiment. @@ -1065,9 +1039,7 @@ def get_prompt_ab_testing(self, name: str): pass @abstractmethod - def update_prompt_ab_testing( - self, name: str, rollouts: List[PromptRollout] - ): + def update_prompt_ab_testing(self, name: str, rollouts: List[PromptRollout]): """ Update the A/B testing configuration for a prompt lineage. diff --git a/literalai/api/helpers/generation_helpers.py b/literalai/api/helpers/generation_helpers.py index 0a287d60..7e08b958 100644 --- a/literalai/api/helpers/generation_helpers.py +++ b/literalai/api/helpers/generation_helpers.py @@ -1,12 +1,13 @@ from typing import Any, Dict, Optional, Union +from literalai.api.helpers import gql +from literalai.my_types import PaginatedResponse from literalai.observability.filter import generations_filters, generations_order_by -from literalai.my_types import ( - PaginatedResponse, +from literalai.observability.generation import ( + BaseGeneration, + ChatGeneration, + CompletionGeneration, ) -from literalai.observability.generation import BaseGeneration, CompletionGeneration, ChatGeneration - -from literalai.api.helpers import gql def get_generations_helper( diff --git a/literalai/api/synchronous.py b/literalai/api/synchronous.py index 43455ee8..2933454d 100644 --- a/literalai/api/synchronous.py +++ b/literalai/api/synchronous.py @@ -1,21 +1,11 @@ import logging import uuid +from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, cast +import httpx from typing_extensions import deprecated -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Optional, - TypeVar, - Union, - cast, -) from literalai.api.base import BaseLiteralAPI, prepare_variables - from literalai.api.helpers.attachment_helpers import ( AttachmentUpload, create_attachment_helper, @@ -90,6 +80,7 @@ DatasetExperimentItem, ) from literalai.evaluation.dataset_item import DatasetItem +from literalai.my_types import PaginatedResponse, User from literalai.observability.filter import ( generations_filters, generations_order_by, @@ -101,12 +92,6 @@ threads_order_by, users_filters, ) -from literalai.observability.thread import Thread -from literalai.prompt_engineering.prompt import Prompt, ProviderSettings - -import httpx - -from literalai.my_types import PaginatedResponse, User from literalai.observability.generation import ( BaseGeneration, ChatGeneration, @@ -122,6 +107,8 @@ StepDict, StepType, ) +from literalai.observability.thread import Thread +from 
literalai.prompt_engineering.prompt import Prompt, ProviderSettings logger = logging.getLogger(__name__) @@ -140,7 +127,11 @@ class LiteralAPI(BaseLiteralAPI): R = TypeVar("R") def make_gql_call( - self, description: str, query: str, variables: dict[str, Any], timeout: Optional[int] = 10 + self, + description: str, + query: str, + variables: dict[str, Any], + timeout: Optional[int] = 10, ) -> dict: def raise_error(error): logger.error(f"Failed to {description}: {error}") @@ -164,8 +155,7 @@ def raise_error(error): json = response.json() except ValueError as e: raise_error( - f"""Failed to parse JSON response: { - e}, content: {response.content!r}""" + f"Failed to parse JSON response: {e}, content: {response.content!r}" ) if json.get("errors"): @@ -176,8 +166,7 @@ def raise_error(error): for value in json["data"].values(): if value and value.get("ok") is False: raise_error( - f"""Failed to {description}: { - value.get('message')}""" + f"""Failed to {description}: {value.get("message")}""" ) return json @@ -202,8 +191,7 @@ def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict: return response.json() except ValueError as e: raise ValueError( - f"""Failed to parse JSON response: { - e}, content: {response.content!r}""" + f"Failed to parse JSON response: {e}, content: {response.content!r}" ) def gql_helper( @@ -230,7 +218,9 @@ def get_users( ) -> PaginatedResponse["User"]: return self.gql_helper(*get_users_helper(first, after, before, filters)) - def get_user(self, id: Optional[str] = None, identifier: Optional[str] = None) -> "User": + def get_user( + self, id: Optional[str] = None, identifier: Optional[str] = None + ) -> "User": return self.gql_helper(*get_user_helper(id, identifier)) def create_user(self, identifier: str, metadata: Optional[Dict] = None) -> "User": @@ -244,7 +234,9 @@ def update_user( def delete_user(self, id: str) -> Dict: return self.gql_helper(*delete_user_helper(id)) - def get_or_create_user(self, identifier: str, metadata: Optional[Dict] = None) -> "User": + def get_or_create_user( + self, identifier: str, metadata: Optional[Dict] = None + ) -> "User": user = self.get_user(identifier=identifier) if user: return user @@ -486,7 +478,7 @@ def create_attachment( thread_id = active_thread.id if not step_id: - if active_steps := active_steps_var.get([]): + if active_steps := active_steps_var.get(): step_id = active_steps[-1].id else: raise Exception("No step_id provided and no active step found.") @@ -525,7 +517,9 @@ def create_attachment( response = self.make_gql_call(description, query, variables) return process_response(response) - def update_attachment(self, id: str, update_params: AttachmentUpload) -> "Attachment": + def update_attachment( + self, id: str, update_params: AttachmentUpload + ) -> "Attachment": return self.gql_helper(*update_attachment_helper(id, update_params)) def get_attachment(self, id: str) -> Optional["Attachment"]: @@ -728,7 +722,7 @@ def create_experiment_item( ################################################################################## # Dataset Item APIs # ################################################################################## - + def create_dataset_item( self, dataset_id: str, @@ -770,7 +764,9 @@ def get_or_create_prompt_lineage( return self.gql_helper(*create_prompt_lineage_helper(name, description)) @deprecated("Use get_or_create_prompt_lineage instead") - def create_prompt_lineage(self, name: str, description: Optional[str] = None) -> Dict: + def create_prompt_lineage( + self, name: str, description: 
Optional[str] = None + ) -> Dict: return self.get_or_create_prompt_lineage(name, description) def get_or_create_prompt( @@ -804,15 +800,26 @@ def get_prompt( if not (id or name): raise ValueError("At least the `id` or the `name` must be provided.") - get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper( - api=self,id=id, name=name, version=version, cache=self.cache + ( + get_prompt_query, + description, + variables, + process_response, + timeout, + cached_prompt, + ) = get_prompt_helper( + api=self, id=id, name=name, version=version, cache=self.cache ) try: if id: - prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout) + prompt = self.gql_helper( + get_prompt_query, description, variables, process_response, timeout + ) elif name: - prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout) + prompt = self.gql_helper( + get_prompt_query, description, variables, process_response, timeout + ) return prompt @@ -820,7 +827,7 @@ def get_prompt( if cached_prompt: logger.warning("Failed to get prompt from API, returning cached prompt") return cached_prompt - + raise e def create_prompt_variant( diff --git a/literalai/cache/shared_cache.py b/literalai/cache/shared_cache.py index b356f32a..6193262a 100644 --- a/literalai/cache/shared_cache.py +++ b/literalai/cache/shared_cache.py @@ -6,6 +6,7 @@ class SharedCache: Singleton cache for storing data. Only one instance will exist regardless of how many times it's instantiated. """ + _instance = None _cache: dict[str, Any] @@ -39,4 +40,3 @@ def clear(self) -> None: Clears all cached values. """ self._cache.clear() - diff --git a/literalai/client.py b/literalai/client.py index bda138cf..586b41db 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -1,10 +1,11 @@ +import io import json import os -from traceloop.sdk import Traceloop +from contextlib import redirect_stdout from typing import Any, Dict, List, Optional, Union + +from traceloop.sdk import Traceloop from typing_extensions import deprecated -import io -from contextlib import redirect_stdout from literalai.api import AsyncLiteralAPI, LiteralAPI from literalai.callback.langchain_callback import get_langchain_callback @@ -343,7 +344,6 @@ def start_step( if hasattr(self, "global_metadata") and self.global_metadata: step.metadata = step.metadata or {} step.metadata.update(self.global_metadata) - step.start() return step @@ -378,7 +378,8 @@ def set_properties( ): thread = active_thread_var.get() root_run = active_root_run_var.get() - parent = active_steps_var.get()[-1] if active_steps_var.get() else None + active_steps = active_steps_var.get() + parent = active_steps[-1] if active_steps else None Traceloop.set_association_properties( { @@ -396,7 +397,7 @@ def reset_context(self): """ Resets the context, forgetting active steps & setting current thread to None. 
""" - active_steps_var.set([]) + active_steps_var.set(None) active_thread_var.set(None) active_root_run_var.set(None) diff --git a/literalai/context.py b/literalai/context.py index 2c790e0a..a16d9582 100644 --- a/literalai/context.py +++ b/literalai/context.py @@ -5,7 +5,7 @@ from literalai.observability.step import Step from literalai.observability.thread import Thread -active_steps_var = ContextVar[List["Step"]]("active_steps", default=[]) +active_steps_var = ContextVar[Optional[List["Step"]]]("active_steps", default=None) active_thread_var = ContextVar[Optional["Thread"]]("active_thread", default=None) active_root_run_var = ContextVar[Optional["Step"]]("active_root_run_var", default=None) diff --git a/literalai/evaluation/dataset.py b/literalai/evaluation/dataset.py index 61789d67..96abb815 100644 --- a/literalai/evaluation/dataset.py +++ b/literalai/evaluation/dataset.py @@ -29,6 +29,7 @@ class Dataset(Utils): """ A dataset of items, each item representing an ideal scenario to run experiments on. """ + api: "LiteralAPI" id: str created_at: str diff --git a/literalai/evaluation/dataset_experiment.py b/literalai/evaluation/dataset_experiment.py index e428036e..2f28b43b 100644 --- a/literalai/evaluation/dataset_experiment.py +++ b/literalai/evaluation/dataset_experiment.py @@ -24,6 +24,7 @@ class DatasetExperimentItem(Utils): """ An item of a `DatasetExperiment`: it may be linked to a `DatasetItem`. """ + id: str dataset_experiment_id: str dataset_item_id: Optional[str] @@ -71,6 +72,7 @@ class DatasetExperiment(Utils): """ An experiment, linked or not to a `Dataset`. """ + api: "LiteralAPI" id: str created_at: str diff --git a/literalai/evaluation/dataset_item.py b/literalai/evaluation/dataset_item.py index a994b19f..ea90ce1b 100644 --- a/literalai/evaluation/dataset_item.py +++ b/literalai/evaluation/dataset_item.py @@ -19,6 +19,7 @@ class DatasetItem(Utils): """ A `Dataset` item, containing `input`, `expectedOutput` and `metadata`. 
""" + id: str created_at: str dataset_id: str diff --git a/literalai/exporter.py b/literalai/exporter.py index 20961163..17f00571 100644 --- a/literalai/exporter.py +++ b/literalai/exporter.py @@ -1,10 +1,10 @@ -from datetime import datetime, timezone import json -from opentelemetry.sdk.trace import ReadableSpan -from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult -from typing import Dict, List, Optional, Sequence, Union, cast import logging +from datetime import datetime, timezone +from typing import Dict, List, Optional, Sequence, Union, cast +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult from literalai.event_processor import EventProcessor from literalai.helper import utc_now @@ -116,10 +116,16 @@ def _create_step_from_span(self, span: ReadableSpan) -> Step: ) messages = self._extract_messages(cast(Dict, attributes)) if is_chat else [] - message_completions = self._extract_messages(cast(Dict, attributes), "gen_ai.completion.") if is_chat else [] + message_completions = ( + self._extract_messages(cast(Dict, attributes), "gen_ai.completion.") + if is_chat + else [] + ) message_completion = message_completions[-1] if message_completions else None - previous_messages = messages + message_completions[:-1] if message_completions else messages + previous_messages = ( + messages + message_completions[:-1] if message_completions else messages + ) generation_content = { "duration": duration, diff --git a/literalai/instrumentation/llamaindex/__init__.py b/literalai/instrumentation/llamaindex/__init__.py index 5379f090..102ce1d8 100644 --- a/literalai/instrumentation/llamaindex/__init__.py +++ b/literalai/instrumentation/llamaindex/__init__.py @@ -1,11 +1,12 @@ -from literalai.client import LiteralClient from llama_index.core.instrumentation import get_dispatcher +from literalai.client import LiteralClient from literalai.instrumentation.llamaindex.event_handler import LiteralEventHandler from literalai.instrumentation.llamaindex.span_handler import LiteralSpanHandler is_llamaindex_instrumented = False + def instrument_llamaindex(client: "LiteralClient"): """ Instruments LlamaIndex to automatically send logs to Literal AI. 
@@ -13,7 +14,7 @@ def instrument_llamaindex(client: "LiteralClient"): global is_llamaindex_instrumented if is_llamaindex_instrumented: return - + root_dispatcher = get_dispatcher() span_handler = LiteralSpanHandler() @@ -23,5 +24,5 @@ def instrument_llamaindex(client: "LiteralClient"): literal_client=client, llama_index_span_handler=span_handler ) root_dispatcher.add_event_handler(event_handler) - + is_llamaindex_instrumented = True diff --git a/literalai/instrumentation/openai.py b/literalai/instrumentation/openai.py index b1554a4d..38fd02f7 100644 --- a/literalai/instrumentation/openai.py +++ b/literalai/instrumentation/openai.py @@ -11,9 +11,9 @@ from literalai.context import active_steps_var, active_thread_var from literalai.helper import ensure_values_serializable from literalai.observability.generation import ( - GenerationMessage, - CompletionGeneration, ChatGeneration, + CompletionGeneration, + GenerationMessage, GenerationType, ) from literalai.wrappers import AfterContext, BeforeContext, wrap_all diff --git a/literalai/my_types.py b/literalai/my_types.py index f45153dc..3a8d36a2 100644 --- a/literalai/my_types.py +++ b/literalai/my_types.py @@ -1,11 +1,10 @@ import json import uuid -from typing import Any, Dict, Generic, List, Literal, Optional, Protocol, TypeVar from abc import abstractmethod - -from typing_extensions import TypedDict +from typing import Any, Dict, Generic, List, Literal, Optional, Protocol, TypeVar from pydantic.dataclasses import Field, dataclass +from typing_extensions import TypedDict Environment = Literal["dev", "staging", "prod", "experiment"] @@ -41,7 +40,9 @@ def from_dict(cls, page_info_dict: Dict) -> "PageInfo": start_cursor = page_info_dict.get("startCursor", None) end_cursor = page_info_dict.get("endCursor", None) return cls( - has_next_page=has_next_page, start_cursor=start_cursor, end_cursor=end_cursor + has_next_page=has_next_page, + start_cursor=start_cursor, + end_cursor=end_cursor, ) diff --git a/literalai/observability/message.py b/literalai/observability/message.py index 511cdf2b..22f0470d 100644 --- a/literalai/observability/message.py +++ b/literalai/observability/message.py @@ -4,10 +4,10 @@ if TYPE_CHECKING: from literalai.event_processor import EventProcessor -from literalai.context import active_steps_var, active_thread_var, active_root_run_var +from literalai.context import active_root_run_var, active_steps_var, active_thread_var from literalai.helper import utc_now from literalai.my_types import Utils -from literalai.observability.step import MessageStepType, StepDict, Score, Attachment +from literalai.observability.step import Attachment, MessageStepType, Score, StepDict class Message(Utils): @@ -73,7 +73,7 @@ def __init__( def end(self): active_steps = active_steps_var.get() - if len(active_steps) > 0: + if active_steps: parent_step = active_steps[-1] if not self.parent_id: self.parent_id = parent_step.id diff --git a/literalai/observability/step.py b/literalai/observability/step.py index 27d3c51b..91c78080 100644 --- a/literalai/observability/step.py +++ b/literalai/observability/step.py @@ -381,7 +381,7 @@ def __init__( def start(self): active_steps = active_steps_var.get() - if len(active_steps) > 0: + if active_steps: parent_step = active_steps[-1] if not self.parent_id: self.parent_id = parent_step.id @@ -398,8 +398,8 @@ def start(self): if active_root_run := active_root_run_var.get(): self.root_run_id = active_root_run.id - active_steps.append(self) - active_steps_var.set(active_steps) + new_steps = (active_steps_var.get() or 
[]) + [self] + active_steps_var.set(new_steps) def end(self): self.end_time = utc_now() @@ -412,8 +412,8 @@ def end(self): raise Exception("Step must be started before ending.") # Remove step from active steps - active_steps.remove(self) - active_steps_var.set(active_steps) + new_steps = [s for s in active_steps if s.id != self.id] + active_steps_var.set(new_steps) if self.processor is None: raise Exception( @@ -507,7 +507,11 @@ def __call__(self, func): self.client, func=func, name=self.step_name, - ctx_manager=self, + type=self.step_type, + id=self.id, + parent_id=self.parent_id, + thread_id=self.thread_id, + root_run_id=self.root_run_id, ) async def __aenter__(self): @@ -611,35 +615,38 @@ def step_decorator( parent_id: Optional[str] = None, thread_id: Optional[str] = None, root_run_id: Optional[str] = None, - ctx_manager: Optional[StepContextManager] = None, **decorator_kwargs, ): if not name: name = func.__name__ - if not ctx_manager: - ctx_manager = StepContextManager( - client=client, - type=type, - name=name, - id=id, - parent_id=parent_id, - thread_id=thread_id, - root_run_id=root_run_id, - **decorator_kwargs, - ) - else: - ctx_manager.step_name = name + # Handle async decorator if inspect.iscoroutinefunction(func): @wraps(func) async def async_wrapper(*args, **kwargs): - with ctx_manager as step: + # Create context manager here, when the function is actually called + ctx = StepContextManager( + client=client, + type=type, + name=name, + id=id, + parent_id=parent_id, + thread_id=thread_id, + root_run_id=root_run_id, + **decorator_kwargs, + ) + + ctx.step_name = name + + async with ctx as step: try: step.input = flatten_args_kwargs(func, *args, **kwargs) except Exception: pass + result = await func(*args, **kwargs) + try: if step.output is None: if isinstance(result, dict): @@ -648,6 +655,7 @@ async def async_wrapper(*args, **kwargs): step.output = {"content": deepcopy(result)} except Exception: pass + return result return async_wrapper @@ -655,12 +663,28 @@ async def async_wrapper(*args, **kwargs): # Handle sync decorator @wraps(func) def sync_wrapper(*args, **kwargs): - with ctx_manager as step: + # Create context manager here, when the function is actually called + ctx = StepContextManager( + client=client, + type=type, + name=name, + id=id, + parent_id=parent_id, + thread_id=thread_id, + root_run_id=root_run_id, + **decorator_kwargs, + ) + + ctx.step_name = name + + with ctx as step: try: step.input = flatten_args_kwargs(func, *args, **kwargs) except Exception: pass + result = func(*args, **kwargs) + try: if step.output is None: if isinstance(result, dict): diff --git a/literalai/version.py b/literalai/version.py index 361ce364..0ebbb206 100644 --- a/literalai/version.py +++ b/literalai/version.py @@ -1 +1 @@ -__version__ = "0.1.106" +__version__ = "0.1.107" diff --git a/literalai/wrappers.py b/literalai/wrappers.py index 9148f4a8..9b07a45d 100644 --- a/literalai/wrappers.py +++ b/literalai/wrappers.py @@ -6,8 +6,8 @@ from literalai.context import active_steps_var if TYPE_CHECKING: + from literalai.observability.generation import ChatGeneration, CompletionGeneration from literalai.observability.step import Step - from literalai.observability.generation import CompletionGeneration, ChatGeneration class BeforeContext(TypedDict): @@ -25,7 +25,7 @@ class AfterContext(TypedDict): def remove_literalai_args(kargs): - '''Remove argument prefixed with "literalai_" from kwargs and return them in a separate dict''' + """Remove argument prefixed with "literalai_" from kwargs and return them 
in a separate dict""" largs = {} for key in list(kargs.keys()): if key.startswith("literalai_"): @@ -33,8 +33,9 @@ def remove_literalai_args(kargs): largs[key] = value return largs + def restore_literalai_args(kargs, largs): - '''Reverse the effect of remove_literalai_args by merging the literal arguments into kwargs''' + """Reverse the effect of remove_literalai_args by merging the literal arguments into kwargs""" for key in list(largs.keys()): kargs[key] = largs[key] diff --git a/setup.py b/setup.py index afef743a..864d9b3a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="literalai", - version="0.1.106", # update version in literalai/version.py + version="0.1.107", # update version in literalai/version.py description="An SDK for observability in Python applications", long_description=open("README.md").read(), long_description_content_type="text/markdown", From 3ae371a45a8ba12fcf5a8fec09d7507bd42e48d1 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Tue, 18 Feb 2025 12:04:30 -0800 Subject: [PATCH 2/2] fix: ci --- literalai/observability/step.py | 1 + tests/e2e/test_e2e.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/literalai/observability/step.py b/literalai/observability/step.py index 91c78080..39d7a0cf 100644 --- a/literalai/observability/step.py +++ b/literalai/observability/step.py @@ -512,6 +512,7 @@ def __call__(self, func): parent_id=self.parent_id, thread_id=self.thread_id, root_run_id=self.root_run_id, + **self.kwargs, ) async def __aenter__(self): diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 426670ba..2546e48f 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -276,12 +276,12 @@ async def test_ingestion( step.metadata = {"foo": "bar"} assert async_client.event_processor.event_queue._qsize() == 0 stack = active_steps_var.get() - assert len(stack) == 1 + assert stack is not None and len(stack) == 1 assert async_client.event_processor.event_queue._qsize() == 1 stack = active_steps_var.get() - assert len(stack) == 0 + assert stack is not None and len(stack) == 0 @pytest.mark.timeout(5) async def test_thread_decorator( @@ -666,14 +666,14 @@ async def test_prompt(self, async_client: AsyncLiteralClient): async def test_prompt_cache(self, async_client: AsyncLiteralClient): prompt = await async_client.api.get_prompt(name="Default", version=0) assert prompt is not None - + original_key = async_client.api.api_key async_client.api.api_key = "invalid-api-key" - + cached_prompt = await async_client.api.get_prompt(name="Default", version=0) assert cached_prompt is not None assert cached_prompt.id == prompt.id - + async_client.api.api_key = original_key @pytest.mark.timeout(5)