From b1cb490b5ceeda78ce69a1c52efc94e4825e2112 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Tue, 25 Mar 2025 18:52:07 +0100 Subject: [PATCH 1/2] feat: add reasoning token support to lc --- literalai/callback/langchain_callback.py | 71 +++++++++++++++++++++--- 1 file changed, 64 insertions(+), 7 deletions(-) diff --git a/literalai/callback/langchain_callback.py b/literalai/callback/langchain_callback.py index 119c340..1e2d4ae 100644 --- a/literalai/callback/langchain_callback.py +++ b/literalai/callback/langchain_callback.py @@ -92,7 +92,33 @@ def _convert_message_dict( if function_call: msg["function_call"] = function_call else: - msg["content"] = kwargs.get("content", "") + content = kwargs.get("content") + if isinstance(content, list): + tool_calls = [] + content_parts = [] + for item in content: + if item.get("type") == "tool_use": + tool_calls.append( + { + "id": item.get("id"), + "type": "function", + "function": { + "name": item.get("name"), + "arguments": item.get("input"), + }, + } + ) + elif item.get("type") == "text": + content_parts.append( + {"type": "text", "text": item.get("text")} + ) + + if tool_calls: + msg["tool_calls"] = tool_calls + if content_parts: + msg["content"] = content_parts # type: ignore + else: + msg["content"] = content # type: ignore if tool_calls: msg["tool_calls"] = tool_calls @@ -123,7 +149,32 @@ def _convert_message( if function_call: msg["function_call"] = function_call else: - msg["content"] = message.content # type: ignore + if isinstance(message.content, list): + tool_calls = [] + content_parts = [] + for item in message.content: + if item.get("type") == "tool_use": + tool_calls.append( + { + "id": item.get("id"), + "type": "function", + "function": { + "name": item.get("name"), + "arguments": item.get("input"), + }, + } + ) + elif item.get("type") == "text": + content_parts.append( + {"type": "text", "text": item.get("text")} + ) + + if tool_calls: + msg["tool_calls"] = tool_calls + if content_parts: + msg["content"] = content_parts # type: ignore + else: + msg["content"] = message.content # type: ignore if tool_calls: msg["tool_calls"] = tool_calls @@ -201,7 +252,12 @@ def _build_llm_settings( {"type": "function", "function": f} for f in settings["functions"] ] if "tools" in settings: - tools = settings["tools"] + tools = [ + {"type": "function", "function": t} + if t.get("type") != "function" + else t + for t in settings["tools"] + ] return provider, model, tools, settings DEFAULT_TO_IGNORE = [ @@ -411,7 +467,9 @@ def _start_trace(self, run: Run) -> None: ) step.tags = run.tags step.metadata = run.metadata - step.input = self.process_content(run.inputs) + + if step.type != "llm": + step.input = self.process_content(run.inputs) self.steps[str(run.id)] = step @@ -484,7 +542,6 @@ def _on_run_update(self, run: Run) -> None: if v is not None } - current_step.output = message_completion else: completion_start = self.completion_generations[str(run.id)] duration = time.time() - completion_start["start"] @@ -509,7 +566,6 @@ def _on_run_update(self, run: Run) -> None: output_token_count=usage_metadata.get("output_tokens"), token_count=usage_metadata.get("total_tokens"), ) - current_step.output = {"content": completion} if current_step: if current_step.metadata is None: @@ -521,7 +577,8 @@ def _on_run_update(self, run: Run) -> None: outputs = run.outputs or {} if current_step: - current_step.output = self.process_content(outputs) + if current_step.type != "llm": + current_step.output = self.process_content(outputs) current_step.end() def _on_error(self, error: BaseException, *, run_id: "UUID", **kwargs: Any): From 561240ccbfb160f779024ba97d74735d9eedcf0c Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Tue, 25 Mar 2025 19:13:30 +0100 Subject: [PATCH 2/2] fix: mypy --- literalai/callback/langchain_callback.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/literalai/callback/langchain_callback.py b/literalai/callback/langchain_callback.py index 1e2d4ae..81045eb 100644 --- a/literalai/callback/langchain_callback.py +++ b/literalai/callback/langchain_callback.py @@ -153,6 +153,8 @@ def _convert_message( tool_calls = [] content_parts = [] for item in message.content: + if isinstance(item, str): + continue if item.get("type") == "tool_use": tool_calls.append( {