From cd823f73b7f749e6b8a60b39a65deb39fb61a1a1 Mon Sep 17 00:00:00 2001 From: Emerson Queiroz Date: Wed, 22 Jan 2025 11:28:48 +0100 Subject: [PATCH] Fixing 'active_steps' getting mixed in async operations --- literalai/client.py | 2 +- literalai/context.py | 2 +- literalai/instrumentation/mistralai.py | 4 ++-- literalai/instrumentation/openai.py | 4 ++-- literalai/observability/message.py | 2 +- literalai/observability/step.py | 4 ++-- literalai/wrappers.py | 4 ++-- tests/e2e/test_e2e.py | 4 ++-- 8 files changed, 13 insertions(+), 13 deletions(-) diff --git a/literalai/client.py b/literalai/client.py index bda138cf..c9a13af8 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -351,7 +351,7 @@ def get_current_step(self): """ Gets the current step from the context. """ - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) if active_steps and len(active_steps) > 0: return active_steps[-1] else: diff --git a/literalai/context.py b/literalai/context.py index 2c790e0a..1e8f6a38 100644 --- a/literalai/context.py +++ b/literalai/context.py @@ -5,7 +5,7 @@ from literalai.observability.step import Step from literalai.observability.thread import Thread -active_steps_var = ContextVar[List["Step"]]("active_steps", default=[]) +active_steps_var = ContextVar[List["Step"]]("active_steps", default=None) active_thread_var = ContextVar[Optional["Thread"]]("active_thread", default=None) active_root_run_var = ContextVar[Optional["Step"]]("active_root_run_var", default=None) diff --git a/literalai/instrumentation/mistralai.py b/literalai/instrumentation/mistralai.py index 881256fb..b9ea378d 100644 --- a/literalai/instrumentation/mistralai.py +++ b/literalai/instrumentation/mistralai.py @@ -202,7 +202,7 @@ def update_step_after( def before_wrapper(metadata: Dict): def before(context: BeforeContext, *args, **kwargs): active_thread = active_thread_var.get() - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) generation = init_generation(metadata["type"], kwargs) if (active_thread or active_steps) and not callable(on_new_generation): @@ -225,7 +225,7 @@ def before(context: BeforeContext, *args, **kwargs): def async_before_wrapper(metadata: Dict): async def before(context: BeforeContext, *args, **kwargs): active_thread = active_thread_var.get() - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) generation = init_generation(metadata["type"], kwargs) if (active_thread or active_steps) and not callable(on_new_generation): diff --git a/literalai/instrumentation/openai.py b/literalai/instrumentation/openai.py index b1554a4d..98c730e4 100644 --- a/literalai/instrumentation/openai.py +++ b/literalai/instrumentation/openai.py @@ -186,7 +186,7 @@ def update_step_after( def before_wrapper(metadata: Dict): def before(context: BeforeContext, *args, **kwargs): active_thread = active_thread_var.get() - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) generation = init_generation(metadata["type"], kwargs) if (active_thread or active_steps) and not callable(on_new_generation): @@ -209,7 +209,7 @@ def before(context: BeforeContext, *args, **kwargs): def async_before_wrapper(metadata: Dict): async def before(context: BeforeContext, *args, **kwargs): active_thread = active_thread_var.get() - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) generation = init_generation(metadata["type"], kwargs) if (active_thread or active_steps) and not callable(on_new_generation): diff --git a/literalai/observability/message.py b/literalai/observability/message.py index 511cdf2b..7807db76 100644 --- a/literalai/observability/message.py +++ b/literalai/observability/message.py @@ -71,7 +71,7 @@ def __init__( self.parent_id = parent_id def end(self): - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) if len(active_steps) > 0: parent_step = active_steps[-1] diff --git a/literalai/observability/step.py b/literalai/observability/step.py index 27d3c51b..48473579 100644 --- a/literalai/observability/step.py +++ b/literalai/observability/step.py @@ -380,7 +380,7 @@ def __init__( setattr(self, key, value) def start(self): - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) if len(active_steps) > 0: parent_step = active_steps[-1] if not self.parent_id: @@ -405,7 +405,7 @@ def end(self): self.end_time = utc_now() # Update active steps - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) # Check if step is active if self not in active_steps: diff --git a/literalai/wrappers.py b/literalai/wrappers.py index 9148f4a8..2394e4fb 100644 --- a/literalai/wrappers.py +++ b/literalai/wrappers.py @@ -53,7 +53,7 @@ def wrapped(*args, **kwargs): try: result = original_func(*args, **kwargs) except Exception as e: - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) if active_steps and len(active_steps) > 0: current_step = active_steps[-1] current_step.error = str(e) @@ -87,7 +87,7 @@ async def wrapped(*args, **kwargs): try: result = await original_func(*args, **kwargs) except Exception as e: - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) if active_steps and len(active_steps) > 0: current_step = active_steps[-1] current_step.error = str(e) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 426670ba..170da32d 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -275,12 +275,12 @@ async def test_ingestion( with async_client.step(name="test_ingestion") as step: step.metadata = {"foo": "bar"} assert async_client.event_processor.event_queue._qsize() == 0 - stack = active_steps_var.get() + stack = active_steps_var.get([]) assert len(stack) == 1 assert async_client.event_processor.event_queue._qsize() == 1 - stack = active_steps_var.get() + stack = active_steps_var.get([]) assert len(stack) == 0 @pytest.mark.timeout(5)