diff --git a/literalai/client.py b/literalai/client.py index bda138c..c9a13af 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -351,7 +351,7 @@ def get_current_step(self): """ Gets the current step from the context. """ - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) if active_steps and len(active_steps) > 0: return active_steps[-1] else: diff --git a/literalai/context.py b/literalai/context.py index 2c790e0..1e8f6a3 100644 --- a/literalai/context.py +++ b/literalai/context.py @@ -5,7 +5,7 @@ from literalai.observability.step import Step from literalai.observability.thread import Thread -active_steps_var = ContextVar[List["Step"]]("active_steps", default=[]) +active_steps_var = ContextVar[List["Step"]]("active_steps", default=None) active_thread_var = ContextVar[Optional["Thread"]]("active_thread", default=None) active_root_run_var = ContextVar[Optional["Step"]]("active_root_run_var", default=None) diff --git a/literalai/instrumentation/mistralai.py b/literalai/instrumentation/mistralai.py index 881256f..b9ea378 100644 --- a/literalai/instrumentation/mistralai.py +++ b/literalai/instrumentation/mistralai.py @@ -202,7 +202,7 @@ def update_step_after( def before_wrapper(metadata: Dict): def before(context: BeforeContext, *args, **kwargs): active_thread = active_thread_var.get() - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) generation = init_generation(metadata["type"], kwargs) if (active_thread or active_steps) and not callable(on_new_generation): @@ -225,7 +225,7 @@ def before(context: BeforeContext, *args, **kwargs): def async_before_wrapper(metadata: Dict): async def before(context: BeforeContext, *args, **kwargs): active_thread = active_thread_var.get() - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) generation = init_generation(metadata["type"], kwargs) if (active_thread or active_steps) and not callable(on_new_generation): diff --git a/literalai/instrumentation/openai.py b/literalai/instrumentation/openai.py index b1554a4..98c730e 100644 --- a/literalai/instrumentation/openai.py +++ b/literalai/instrumentation/openai.py @@ -186,7 +186,7 @@ def update_step_after( def before_wrapper(metadata: Dict): def before(context: BeforeContext, *args, **kwargs): active_thread = active_thread_var.get() - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) generation = init_generation(metadata["type"], kwargs) if (active_thread or active_steps) and not callable(on_new_generation): @@ -209,7 +209,7 @@ def before(context: BeforeContext, *args, **kwargs): def async_before_wrapper(metadata: Dict): async def before(context: BeforeContext, *args, **kwargs): active_thread = active_thread_var.get() - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) generation = init_generation(metadata["type"], kwargs) if (active_thread or active_steps) and not callable(on_new_generation): diff --git a/literalai/observability/message.py b/literalai/observability/message.py index 511cdf2..7807db7 100644 --- a/literalai/observability/message.py +++ b/literalai/observability/message.py @@ -71,7 +71,7 @@ def __init__( self.parent_id = parent_id def end(self): - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) if len(active_steps) > 0: parent_step = active_steps[-1] diff --git a/literalai/observability/step.py b/literalai/observability/step.py index 27d3c51..4847357 100644 --- a/literalai/observability/step.py +++ b/literalai/observability/step.py @@ -380,7 +380,7 @@ def __init__( setattr(self, key, value) def start(self): - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) if len(active_steps) > 0: parent_step = active_steps[-1] if not self.parent_id: @@ -405,7 +405,7 @@ def end(self): self.end_time = utc_now() # Update active steps - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) # Check if step is active if self not in active_steps: diff --git a/literalai/wrappers.py b/literalai/wrappers.py index 9148f4a..2394e4f 100644 --- a/literalai/wrappers.py +++ b/literalai/wrappers.py @@ -53,7 +53,7 @@ def wrapped(*args, **kwargs): try: result = original_func(*args, **kwargs) except Exception as e: - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) if active_steps and len(active_steps) > 0: current_step = active_steps[-1] current_step.error = str(e) @@ -87,7 +87,7 @@ async def wrapped(*args, **kwargs): try: result = await original_func(*args, **kwargs) except Exception as e: - active_steps = active_steps_var.get() + active_steps = active_steps_var.get([]) if active_steps and len(active_steps) > 0: current_step = active_steps[-1] current_step.error = str(e) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 426670b..170da32 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -275,12 +275,12 @@ async def test_ingestion( with async_client.step(name="test_ingestion") as step: step.metadata = {"foo": "bar"} assert async_client.event_processor.event_queue._qsize() == 0 - stack = active_steps_var.get() + stack = active_steps_var.get([]) assert len(stack) == 1 assert async_client.event_processor.event_queue._qsize() == 1 - stack = active_steps_var.get() + stack = active_steps_var.get([]) assert len(stack) == 0 @pytest.mark.timeout(5)