diff --git a/literalai/client.py b/literalai/client.py index 7628739..c3743cd 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -46,6 +46,7 @@ class BaseLiteralClient: """ api: Union[LiteralAPI, AsyncLiteralAPI] + global_metadata: Optional[Dict[str, str]] = None def __init__( self, @@ -55,6 +56,7 @@ def __init__( url: Optional[str] = None, environment: Optional[Environment] = None, disabled: bool = False, + release: Optional[str] = None, ): if not api_key: api_key = os.getenv("LITERAL_API_KEY", None) @@ -77,6 +79,9 @@ def __init__( disabled=self.disabled, ) + if release and release.strip(): + self.global_metadata = {"release": release.strip()} + def to_sync(self) -> "LiteralClient": if isinstance(self.api, AsyncLiteralAPI): return LiteralClient( @@ -223,6 +228,9 @@ def message( metadata: Dict = {}, root_run_id: Optional[str] = None, ): + if hasattr(self, "global_metadata") and self.global_metadata: + metadata.update(self.global_metadata) + step = Message( name=name, id=id, @@ -313,6 +321,11 @@ def start_step( root_run_id=root_run_id, **kwargs, ) + + if hasattr(self, "global_metadata") and self.global_metadata: + step.metadata = step.metadata or {} + step.metadata.update(self.global_metadata) + step.start() return step @@ -373,6 +386,7 @@ def __init__( url: Optional[str] = None, environment: Optional[Environment] = None, disabled: bool = False, + release: Optional[str] = None, ): super().__init__( batch_size=batch_size, @@ -381,6 +395,7 @@ def __init__( url=url, disabled=disabled, environment=environment, + release=release, ) def flush(self): @@ -407,6 +422,7 @@ def __init__( url: Optional[str] = None, environment: Optional[Environment] = None, disabled: bool = False, + release: Optional[str] = None, ): super().__init__( batch_size=batch_size, @@ -415,6 +431,7 @@ def __init__( url=url, disabled=disabled, environment=environment, + release=release, ) async def flush(self): diff --git a/literalai/observability/step.py b/literalai/observability/step.py index f62eeae..d240a8e 100644 --- a/literalai/observability/step.py +++ b/literalai/observability/step.py @@ -379,6 +379,7 @@ def __init__( processor: Optional["EventProcessor"] = None, tags: Optional[List[str]] = None, root_run_id: Optional[str] = None, + metadata: Optional[Dict] = None, ): from time import sleep @@ -400,6 +401,8 @@ def __init__( self.parent_id = parent_id self.tags = tags + if metadata: + self.metadata = metadata def start(self): active_steps = active_steps_var.get() @@ -540,6 +543,7 @@ async def __aenter__(self): parent_id=self.parent_id, thread_id=self.thread_id, root_run_id=self.root_run_id, + metadata=self.kwargs.get("metadata", None), **self.kwargs, ) @@ -566,6 +570,7 @@ def __enter__(self) -> Step: parent_id=self.parent_id, thread_id=self.thread_id, root_run_id=self.root_run_id, + metadata=self.kwargs.get("metadata", None), **self.kwargs, ) diff --git a/literalai/observability/thread.py b/literalai/observability/thread.py index 85fb4c9..87a22d4 100644 --- a/literalai/observability/thread.py +++ b/literalai/observability/thread.py @@ -102,11 +102,11 @@ def to_dict(self) -> ThreadDict: "tags": self.tags, "name": self.name, "steps": [step.to_dict() for step in self.steps] if self.steps else [], - "participant": UserDict( - id=self.participant_id, identifier=self.participant_identifier - ) - if self.participant_id - else UserDict(), + "participant": ( + UserDict(id=self.participant_id, identifier=self.participant_identifier) + if self.participant_id + else UserDict() + ), "createdAt": getattr(self, "created_at", None), } @@ -163,7 +163,12 @@ def upsert(self): "id": thread_data["id"], "name": thread_data["name"], } - if metadata := thread_data.get("metadata"): + + metadata = { + **(self.client.global_metadata or {}), + **(thread_data.get("metadata") or {}), + } + if metadata: thread_data_to_upsert["metadata"] = metadata if tags := thread_data.get("tags"): thread_data_to_upsert["tags"] = tags diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index c858ae6..992c3a5 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -315,6 +315,36 @@ async def a_thread_decorated(): id = await a_thread_decorated() await assert_delete(id) + @pytest.mark.timeout(5) + async def test_release(self): + url = os.getenv("LITERAL_API_URL", None) + api_key = os.getenv("LITERAL_API_KEY", None) + async_client = AsyncLiteralClient( + batch_size=5, url=url, api_key=api_key, release="test" + ) + + async def assert_delete(thread_id: str, step_id: str): + await async_client.flush() + assert await async_client.api.delete_step(step_id) is True + assert await async_client.api.delete_thread(thread_id) is True + + @async_client.thread + def thread_decorated(): + @async_client.step(name="foo", type="llm") + def step_decorated(): + t = async_client.get_current_thread() + s = async_client.get_current_step() + assert s is not None + assert s.name == "foo" + assert s.type == "llm" + assert s.metadata.get("release") == "test" + return t.id, s.id + + return step_decorated() + + thread_id, step_id = thread_decorated() + await assert_delete(thread_id, step_id) + @pytest.mark.timeout(5) async def test_step_decorator( self, client: LiteralClient, async_client: AsyncLiteralClient @@ -600,7 +630,9 @@ async def test_prompt(self, async_client: AsyncLiteralClient): assert prompt.name == "Default" assert prompt.version == 0 assert prompt.provider == "openai" - assert prompt.url.endswith(f"projects/{project}/playground?name=Default&version=0") + assert prompt.url.endswith( + f"projects/{project}/playground?name=Default&version=0" + ) prompt = await async_client.api.get_prompt(id=prompt.id, version=0) assert prompt is not None