diff --git a/literalai/callback/langchain_callback.py b/literalai/callback/langchain_callback.py index 119c340..81045eb 100644 --- a/literalai/callback/langchain_callback.py +++ b/literalai/callback/langchain_callback.py @@ -92,7 +92,33 @@ def _convert_message_dict( if function_call: msg["function_call"] = function_call else: - msg["content"] = kwargs.get("content", "") + content = kwargs.get("content") + if isinstance(content, list): + tool_calls = [] + content_parts = [] + for item in content: + if item.get("type") == "tool_use": + tool_calls.append( + { + "id": item.get("id"), + "type": "function", + "function": { + "name": item.get("name"), + "arguments": item.get("input"), + }, + } + ) + elif item.get("type") == "text": + content_parts.append( + {"type": "text", "text": item.get("text")} + ) + + if tool_calls: + msg["tool_calls"] = tool_calls + if content_parts: + msg["content"] = content_parts # type: ignore + else: + msg["content"] = content # type: ignore if tool_calls: msg["tool_calls"] = tool_calls @@ -123,7 +149,34 @@ def _convert_message( if function_call: msg["function_call"] = function_call else: - msg["content"] = message.content # type: ignore + if isinstance(message.content, list): + tool_calls = [] + content_parts = [] + for item in message.content: + if isinstance(item, str): + continue + if item.get("type") == "tool_use": + tool_calls.append( + { + "id": item.get("id"), + "type": "function", + "function": { + "name": item.get("name"), + "arguments": item.get("input"), + }, + } + ) + elif item.get("type") == "text": + content_parts.append( + {"type": "text", "text": item.get("text")} + ) + + if tool_calls: + msg["tool_calls"] = tool_calls + if content_parts: + msg["content"] = content_parts # type: ignore + else: + msg["content"] = message.content # type: ignore if tool_calls: msg["tool_calls"] = tool_calls @@ -201,7 +254,12 @@ def _build_llm_settings( {"type": "function", "function": f} for f in settings["functions"] ] if "tools" in settings: - tools = settings["tools"] + tools = [ + {"type": "function", "function": t} + if t.get("type") != "function" + else t + for t in settings["tools"] + ] return provider, model, tools, settings DEFAULT_TO_IGNORE = [ @@ -411,7 +469,9 @@ def _start_trace(self, run: Run) -> None: ) step.tags = run.tags step.metadata = run.metadata - step.input = self.process_content(run.inputs) + + if step.type != "llm": + step.input = self.process_content(run.inputs) self.steps[str(run.id)] = step @@ -484,7 +544,6 @@ def _on_run_update(self, run: Run) -> None: if v is not None } - current_step.output = message_completion else: completion_start = self.completion_generations[str(run.id)] duration = time.time() - completion_start["start"] @@ -509,7 +568,6 @@ def _on_run_update(self, run: Run) -> None: output_token_count=usage_metadata.get("output_tokens"), token_count=usage_metadata.get("total_tokens"), ) - current_step.output = {"content": completion} if current_step: if current_step.metadata is None: @@ -521,7 +579,8 @@ def _on_run_update(self, run: Run) -> None: outputs = run.outputs or {} if current_step: - current_step.output = self.process_content(outputs) + if current_step.type != "llm": + current_step.output = self.process_content(outputs) current_step.end() def _on_error(self, error: BaseException, *, run_id: "UUID", **kwargs: Any):