diff --git a/literalai/callback/langchain_callback.py b/literalai/callback/langchain_callback.py index 43dba05..170f22e 100644 --- a/literalai/callback/langchain_callback.py +++ b/literalai/callback/langchain_callback.py @@ -337,7 +337,8 @@ def _start_trace(self, run: Run) -> None: if run.run_type == "agent": step_type = "run" elif run.run_type == "chain": - pass + if not self.steps: + step_type = "run" elif run.run_type == "llm": step_type = "llm" elif run.run_type == "retriever": @@ -347,9 +348,6 @@ def _start_trace(self, run: Run) -> None: elif run.run_type == "embedding": step_type = "embedding" - if not self.steps and step_type != "llm": - step_type = "run" - step = self.client.start_step( id=str(run.id), name=run.name, type=step_type, parent_id=parent_id ) diff --git a/literalai/callback/llama_index_callback.py b/literalai/callback/llama_index_callback.py index 9b61633..11f74ac 100644 --- a/literalai/callback/llama_index_callback.py +++ b/literalai/callback/llama_index_callback.py @@ -55,7 +55,6 @@ def __init__( event_ends_to_ignore=event_ends_to_ignore, ) self.client = client - self.is_pristine = True self.steps = {} @@ -91,12 +90,6 @@ def on_event_start( else: return event_id - step_type = ( - "run" if self.is_pristine and step_type != "llm" else "undefined" - ) - - self.is_pristine = False - step = self.client.start_step( name=event_type.value, type=step_type, diff --git a/tests/e2e/test_mistralai.py b/tests/e2e/test_mistralai.py index 5b5f611..b858bcb 100644 --- a/tests/e2e/test_mistralai.py +++ b/tests/e2e/test_mistralai.py @@ -1,12 +1,13 @@ -from asyncio import sleep -from literalai.client import LiteralClient -import pytest import os -from pytest_httpx import HTTPXMock import urllib.parse -from mistralai.client import MistralClient +from asyncio import sleep + +import pytest from mistralai.async_client import MistralAsyncClient +from mistralai.client import MistralClient +from pytest_httpx import HTTPXMock +from literalai.client import LiteralClient from literalai.my_types import ChatGeneration, CompletionGeneration @@ -92,7 +93,7 @@ def main(): assert step.type == "llm" assert step.generation is not None - assert type(step.generation) == ChatGeneration + assert type(step.generation) is ChatGeneration assert step.generation.settings is not None assert step.generation.model == "open-mistral-7b" @@ -148,7 +149,7 @@ def main(): assert step.type == "llm" assert step.generation is not None - assert type(step.generation) == CompletionGeneration + assert type(step.generation) is CompletionGeneration assert step.generation.settings is not None assert step.generation.model == "codestral-2405" assert step.generation.completion == "2\n\n" @@ -187,7 +188,6 @@ async def test_async_chat(self, client: "LiteralClient", httpx_mock: "HTTPXMock" @client.thread async def main(): - # https://docs.mistral.ai/api/#operation/createChatCompletion await mai_client.chat( model="open-mistral-7b", @@ -213,7 +213,7 @@ async def main(): assert step.type == "llm" assert step.generation is not None - assert type(step.generation) == ChatGeneration + assert type(step.generation) is ChatGeneration assert step.generation.settings is not None assert step.generation.model == "open-mistral-7b" @@ -271,7 +271,7 @@ async def main(): assert step.type == "llm" assert step.generation is not None - assert type(step.generation) == CompletionGeneration + assert type(step.generation) is CompletionGeneration assert step.generation.settings is not None assert step.generation.model == "codestral-2405" assert step.generation.completion == "2\n\n" diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index e006d4d..c0aec52 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -99,7 +99,7 @@ def main(): assert step.type == "llm" assert step.generation is not None - assert type(step.generation) == ChatGeneration + assert type(step.generation) is ChatGeneration assert step.generation.settings is not None assert step.generation.model == "gpt-3.5-turbo-0613" @@ -155,7 +155,7 @@ def main(): assert step.type == "llm" assert step.generation is not None - assert type(step.generation) == CompletionGeneration + assert type(step.generation) is CompletionGeneration assert step.generation.settings is not None assert step.generation.model == "gpt-3.5-turbo" assert step.generation.completion == "\n\nThis is indeed a test" @@ -223,7 +223,7 @@ async def main(): assert step.type == "llm" assert step.generation is not None - assert type(step.generation) == ChatGeneration + assert type(step.generation) is ChatGeneration assert step.generation.settings is not None assert step.generation.model == "gpt-3.5-turbo-0613" @@ -281,7 +281,7 @@ async def main(): assert step.type == "llm" assert step.generation is not None - assert type(step.generation) == CompletionGeneration + assert type(step.generation) is CompletionGeneration assert step.generation.settings is not None assert step.generation.model == "gpt-3.5-turbo" assert step.generation.completion == "\n\nThis is indeed a test" @@ -346,7 +346,7 @@ def main(): assert step.type == "llm" assert step.generation is not None - assert type(step.generation) == CompletionGeneration + assert type(step.generation) is CompletionGeneration assert step.generation.settings is not None assert step.generation.model == "gpt-3.5-turbo" assert step.generation.completion == "\n\nThis is indeed a test"