diff --git a/src/llm/claude_client.py b/src/llm/claude_client.py index 68fa6eba..6e09c945 100644 --- a/src/llm/claude_client.py +++ b/src/llm/claude_client.py @@ -1,6 +1,8 @@ +from typing import Optional, List, Dict, Any from anthropic import Anthropic from src.config import Config +from src.llm.tools import AVAILABLE_TOOLS, Tool class Claude: def __init__(self): @@ -9,18 +11,47 @@ def __init__(self): self.client = Anthropic( api_key=api_key, ) + self.tool_schemas = [ + { + "name": tool.name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": { + name: { + "type": param.type, + "description": param.description, + **({"enum": param.enum} if param.enum else {}) + } + for name, param in tool.parameters.items() + }, + "required": tool.required + } + } + for tool in AVAILABLE_TOOLS + ] - def inference(self, model_id: str, prompt: str) -> str: - message = self.client.messages.create( - max_tokens=4096, - messages=[ + def inference( + self, + model_id: str, + prompt: str, + tools: Optional[List[Dict[str, Any]]] = None + ) -> str: + kwargs = { + "max_tokens": 4096, + "messages": [ { "role": "user", "content": prompt.strip(), } ], - model=model_id, - temperature=0 - ) + "model": model_id, + "temperature": 0 + } + + # Add tool schemas for Claude 3 models + if "claude-3" in model_id: + kwargs["tools"] = tools or self.tool_schemas + message = self.client.messages.create(**kwargs) return message.content[0].text diff --git a/src/llm/tools.py b/src/llm/tools.py new file mode 100644 index 00000000..4d53ab3a --- /dev/null +++ b/src/llm/tools.py @@ -0,0 +1,93 @@ +"""Tool schemas for Claude 3 function calling. + +This module defines the tool schemas for Claude 3's function calling capabilities. +Each tool follows Claude's function calling schema format. +""" + +from typing import Dict, List, Optional +from dataclasses import dataclass + + +@dataclass +class ToolParameter: + """Parameter definition for a tool.""" + type: str + description: str + enum: Optional[List[str]] = None + required: bool = True + + +@dataclass +class Tool: + """Tool definition following Claude's schema.""" + name: str + description: str + parameters: Dict[str, ToolParameter] + required: List[str] + + +# Core tool definitions +BROWSE_TOOL = Tool( + name="browse_web", + description="Browse a web page and extract its content", + parameters={ + "url": ToolParameter( + type="string", + description="The URL to browse" + ) + }, + required=["url"] +) + +READ_FILE_TOOL = Tool( + name="read_file", + description="Read the contents of a file", + parameters={ + "path": ToolParameter( + type="string", + description="The path to the file to read" + ) + }, + required=["path"] +) + +WRITE_FILE_TOOL = Tool( + name="write_file", + description="Write content to a file", + parameters={ + "path": ToolParameter( + type="string", + description="The path to write the file to" + ), + "content": ToolParameter( + type="string", + description="The content to write to the file" + ) + }, + required=["path", "content"] +) + +RUN_CODE_TOOL = Tool( + name="run_code", + description="Execute code in a sandboxed environment", + parameters={ + "code": ToolParameter( + type="string", + description="The code to execute" + ), + "language": ToolParameter( + type="string", + description="The programming language", + enum=["python", "javascript", "bash"] + ) + }, + required=["code", "language"] +) + +# List of all available tools +AVAILABLE_TOOLS = [ + BROWSE_TOOL, + READ_FILE_TOOL, + WRITE_FILE_TOOL, + RUN_CODE_TOOL +] diff --git a/tests/test_claude_tools.py b/tests/test_claude_tools.py new file mode 100644 index 00000000..6e417464 --- /dev/null +++ b/tests/test_claude_tools.py @@ -0,0 +1,143 @@ +"""Tests for Claude 3 tool use functionality.""" + +import pytest +from typing import Dict, Any + +from src.llm.tools import ( + Tool, + ToolParameter, + BROWSE_TOOL, + READ_FILE_TOOL, + WRITE_FILE_TOOL, + RUN_CODE_TOOL, + AVAILABLE_TOOLS +) +from src.llm.claude_client import Claude + + +def test_tool_parameter_creation(): + """Test creating tool parameters with various configurations.""" + param = ToolParameter( + type="string", + description="Test parameter", + enum=["a", "b", "c"], + required=True + ) + assert param.type == "string" + assert param.description == "Test parameter" + assert param.enum == ["a", "b", "c"] + assert param.required is True + + # Test without optional fields + basic_param = ToolParameter( + type="integer", + description="Basic parameter" + ) + assert basic_param.type == "integer" + assert basic_param.description == "Basic parameter" + assert basic_param.enum is None + assert basic_param.required is True + + +def test_tool_creation(): + """Test creating tools with parameters.""" + tool = Tool( + name="test_tool", + description="Test tool", + parameters={ + "param1": ToolParameter( + type="string", + description="Parameter 1" + ) + }, + required=["param1"] + ) + assert tool.name == "test_tool" + assert tool.description == "Test tool" + assert len(tool.parameters) == 1 + assert "param1" in tool.parameters + assert tool.required == ["param1"] + + +def test_browse_tool_schema(): + """Test browse tool schema structure.""" + assert BROWSE_TOOL.name == "browse_web" + assert "url" in BROWSE_TOOL.parameters + assert BROWSE_TOOL.parameters["url"].type == "string" + assert BROWSE_TOOL.required == ["url"] + + +def test_read_file_tool_schema(): + """Test read file tool schema structure.""" + assert READ_FILE_TOOL.name == "read_file" + assert "path" in READ_FILE_TOOL.parameters + assert READ_FILE_TOOL.parameters["path"].type == "string" + assert READ_FILE_TOOL.required == ["path"] + + +def test_write_file_tool_schema(): + """Test write file tool schema structure.""" + assert WRITE_FILE_TOOL.name == "write_file" + assert "path" in WRITE_FILE_TOOL.parameters + assert "content" in WRITE_FILE_TOOL.parameters + assert WRITE_FILE_TOOL.required == ["path", "content"] + + +def test_run_code_tool_schema(): + """Test run code tool schema structure.""" + assert RUN_CODE_TOOL.name == "run_code" + assert "code" in RUN_CODE_TOOL.parameters + assert "language" in RUN_CODE_TOOL.parameters + assert RUN_CODE_TOOL.parameters["language"].enum == ["python", "javascript", "bash"] + assert RUN_CODE_TOOL.required == ["code", "language"] + + +def test_claude_client_tool_schemas(): + """Test Claude client tool schema generation.""" + client = Claude() + + # Verify tool schemas are properly formatted for Claude API + assert len(client.tool_schemas) == len(AVAILABLE_TOOLS) + + # Check schema structure for first tool + schema = client.tool_schemas[0] + assert isinstance(schema, dict) + assert "name" in schema + assert "description" in schema + assert "parameters" in schema + assert schema["parameters"]["type"] == "object" + assert "properties" in schema["parameters"] + assert "required" in schema["parameters"] + + +@pytest.mark.parametrize("model_id,should_have_tools", [ + ("claude-3-opus-20240229", True), + ("claude-3-sonnet-20240229", True), + ("claude-2.1", False), + ("claude-2.0", False), +]) +def test_claude_inference_tool_inclusion(model_id: str, should_have_tools: bool): + """Test tool inclusion in Claude inference based on model.""" + client = Claude() + prompt = "Test prompt" + + # Mock the create method to capture kwargs + def mock_create(**kwargs) -> Dict[str, Any]: + class MockResponse: + content = [type("Content", (), {"text": "Mock response"})] + + # Verify tools presence based on model + if should_have_tools: + assert "tools" in kwargs + assert isinstance(kwargs["tools"], list) + assert len(kwargs["tools"]) == len(AVAILABLE_TOOLS) + else: + assert "tools" not in kwargs + + return MockResponse() + + # Replace create method with mock + client.client.messages.create = mock_create + + # Run inference + client.inference(model_id=model_id, prompt=prompt)