diff --git a/README.md b/README.md index 51a0c2c..9ccad4d 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,115 @@ mcp.mount() That's it! Your auto-generated MCP server is now available at `https://app.base.url/mcp`. +## MCP Prompts Support + +FastAPI-MCP automatically generates helpful prompts for each of your API endpoints and supports custom prompts for enhanced AI interactions. + +### Auto-Generated Tool Prompts + +By default, every FastAPI endpoint automatically gets a corresponding prompt (named `use_{endpoint_name}`) that provides AI models with guidance on how to use that specific tool: + +```python +from fastapi import FastAPI +from fastapi_mcp import FastApiMCP + +app = FastAPI() + +@app.post("/create_item") +async def create_item(name: str, price: float): + """Create a new item in the inventory.""" + return {"name": name, "price": price} + +# Auto-generation is enabled by default +mcp = FastApiMCP(app, auto_generate_prompts=True) # This is the default +mcp.mount() + +# Automatically creates a prompt named "use_create_item" with guidance +# on how to use the create_item tool effectively +``` + +#### Controlling Auto-Generation + +You have full control over prompt auto-generation: + +```python +# Option 1: Auto-generated prompts only (default) +mcp = FastApiMCP(app, auto_generate_prompts=True) + +# Option 2: Disable auto-generation, use only custom prompts +mcp = FastApiMCP(app, auto_generate_prompts=False) + +# Option 3: Mixed approach - auto-generated + custom overrides +mcp = FastApiMCP(app, auto_generate_prompts=True) +# Then add custom prompts or override auto-generated ones +``` + +### Custom Prompt Overrides + +You can override auto-generated prompts or create entirely custom ones: + +```python +# Override the auto-generated prompt for better guidance +@mcp.prompt("use_create_item", title="Item Creation Guide") +def create_item_guide(): + return PromptMessage( + role="user", + content=TextContent( + text="""Use the create_item tool to add items to inventory. + +Best Practices: +- Use descriptive names (e.g., "Wireless Bluetooth Mouse") +- Set realistic prices in decimal format (e.g., 29.99) +- Include detailed descriptions for better categorization + +This tool will validate inputs and return the created item details.""" + ) + ) +``` + +### API Documentation Prompts + +Create dynamic prompts that help with API understanding: + +```python +@mcp.prompt("api_documentation") +def api_docs_prompt(endpoint_path: Optional[str] = None): + if endpoint_path: + return PromptMessage( + role="user", + content=TextContent( + text=f"Please provide comprehensive documentation for {endpoint_path}, including parameters, examples, and use cases." + ) + ) + else: + # Generate overview of all endpoints + return PromptMessage( + role="user", + content=TextContent(text="Please explain this API's purpose and how to use its endpoints effectively.") + ) +``` + +### Welcome and Troubleshooting Prompts + +```python +@mcp.prompt("welcome") +def welcome_prompt(): + return PromptMessage( + role="user", + content=TextContent(text="Please provide a friendly welcome message for API users.") + ) + +@mcp.prompt("troubleshoot") +async def troubleshoot_prompt(error_message: str, endpoint: Optional[str] = None): + return PromptMessage( + role="user", + content=TextContent( + text=f"Help troubleshoot this API issue: {error_message}" + + (f" on endpoint {endpoint}" if endpoint else "") + ) + ) +``` + ## Documentation, Examples and Advanced Usage FastAPI-MCP provides [comprehensive documentation](https://fastapi-mcp.tadata.com/). Additionaly, check out the [examples directory](examples) for code samples demonstrating these features in action. diff --git a/examples/01_basic_usage_example.py b/examples/01_basic_usage_example.py index 470dab9..d387458 100644 --- a/examples/01_basic_usage_example.py +++ b/examples/01_basic_usage_example.py @@ -1,4 +1,4 @@ -from examples.shared.apps.items import app # The FastAPI app +from examples.shared.apps.items import app # The FastAPI app from examples.shared.setup import setup_logging from fastapi_mcp import FastApiMCP @@ -15,4 +15,4 @@ if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/02_full_schema_description_example.py b/examples/02_full_schema_description_example.py index 9750c33..9210e18 100644 --- a/examples/02_full_schema_description_example.py +++ b/examples/02_full_schema_description_example.py @@ -1,8 +1,8 @@ - """ This example shows how to describe the full response schema instead of just a response example. """ -from examples.shared.apps.items import app # The FastAPI app + +from examples.shared.apps.items import app # The FastAPI app from examples.shared.setup import setup_logging from fastapi_mcp import FastApiMCP @@ -22,5 +22,5 @@ if __name__ == "__main__": import uvicorn - + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/03_custom_exposed_endpoints_example.py b/examples/03_custom_exposed_endpoints_example.py index 59e46e6..8d21ed8 100644 --- a/examples/03_custom_exposed_endpoints_example.py +++ b/examples/03_custom_exposed_endpoints_example.py @@ -6,7 +6,8 @@ - You can combine operation filtering with tag filtering (e.g., use `include_operations` with `include_tags`) - When combining filters, a greedy approach will be taken. Endpoints matching either criteria will be included """ -from examples.shared.apps.items import app # The FastAPI app + +from examples.shared.apps.items import app # The FastAPI app from examples.shared.setup import setup_logging from fastapi_mcp import FastApiMCP @@ -24,7 +25,7 @@ # Filter by excluding specific operation IDs exclude_operations_mcp = FastApiMCP( - app, + app, name="Item API MCP - Excluded Operations", exclude_operations=["create_item", "update_item", "delete_item"], ) diff --git a/examples/04_separate_server_example.py b/examples/04_separate_server_example.py index e468557..80f10da 100644 --- a/examples/04_separate_server_example.py +++ b/examples/04_separate_server_example.py @@ -2,6 +2,7 @@ This example shows how to run the MCP server and the FastAPI app separately. You can create an MCP server from one FastAPI app, and mount it to a different app. """ + from fastapi import FastAPI from examples.shared.apps.items import app @@ -30,4 +31,4 @@ if __name__ == "__main__": import uvicorn - uvicorn.run(mcp_app, host="0.0.0.0", port=8000) \ No newline at end of file + uvicorn.run(mcp_app, host="0.0.0.0", port=8000) diff --git a/examples/05_reregister_tools_example.py b/examples/05_reregister_tools_example.py index d30ce49..14e6f41 100644 --- a/examples/05_reregister_tools_example.py +++ b/examples/05_reregister_tools_example.py @@ -1,15 +1,16 @@ """ This example shows how to re-register tools if you add endpoints after the MCP server was created. """ -from examples.shared.apps.items import app # The FastAPI app + +from examples.shared.apps.items import app # The FastAPI app from examples.shared.setup import setup_logging from fastapi_mcp import FastApiMCP setup_logging() -mcp = FastApiMCP(app) # Add MCP server to the FastAPI app -mcp.mount() # MCP server +mcp = FastApiMCP(app) # Add MCP server to the FastAPI app +mcp.mount() # MCP server # This endpoint will not be registered as a tool, since it was added after the MCP instance was created @@ -24,5 +25,5 @@ async def new_endpoint(): if __name__ == "__main__": import uvicorn - + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/06_custom_mcp_router_example.py b/examples/06_custom_mcp_router_example.py index 83ea6ad..b69ac09 100644 --- a/examples/06_custom_mcp_router_example.py +++ b/examples/06_custom_mcp_router_example.py @@ -1,7 +1,8 @@ """ This example shows how to mount the MCP server to a specific APIRouter, giving a custom mount path. """ -from examples.shared.apps.items import app # The FastAPI app + +from examples.shared.apps.items import app # The FastAPI app from examples.shared.setup import setup_logging from fastapi import APIRouter @@ -9,7 +10,7 @@ setup_logging() -other_router = APIRouter(prefix="/other/route") +other_router = APIRouter(prefix="/other/route") app.include_router(other_router) mcp = FastApiMCP(app) @@ -21,5 +22,5 @@ if __name__ == "__main__": import uvicorn - + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/07_configure_http_timeout_example.py b/examples/07_configure_http_timeout_example.py index eaab570..036c103 100644 --- a/examples/07_configure_http_timeout_example.py +++ b/examples/07_configure_http_timeout_example.py @@ -2,7 +2,8 @@ This example shows how to configure the HTTP client timeout for the MCP server. In case you have API endpoints that take longer than 5 seconds to respond, you can increase the timeout. """ -from examples.shared.apps.items import app # The FastAPI app + +from examples.shared.apps.items import app # The FastAPI app from examples.shared.setup import setup_logging import httpx @@ -12,14 +13,11 @@ setup_logging() -mcp = FastApiMCP( - app, - http_client=httpx.AsyncClient(timeout=20) -) +mcp = FastApiMCP(app, http_client=httpx.AsyncClient(timeout=20)) mcp.mount() if __name__ == "__main__": import uvicorn - + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/08_auth_example_token_passthrough.py b/examples/08_auth_example_token_passthrough.py index 8f0b8f4..422f4ab 100644 --- a/examples/08_auth_example_token_passthrough.py +++ b/examples/08_auth_example_token_passthrough.py @@ -21,7 +21,8 @@ } ``` """ -from examples.shared.apps.items import app # The FastAPI app + +from examples.shared.apps.items import app # The FastAPI app from examples.shared.setup import setup_logging from fastapi import Depends @@ -34,11 +35,13 @@ # Scheme for the Authorization header token_auth_scheme = HTTPBearer() + # Create a private endpoint @app.get("/private") -async def private(token = Depends(token_auth_scheme)): +async def private(token=Depends(token_auth_scheme)): return token.credentials + # Create the MCP server with the token auth scheme mcp = FastApiMCP( app, @@ -54,5 +57,5 @@ async def private(token = Depends(token_auth_scheme)): if __name__ == "__main__": import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/10_prompts_example.py b/examples/10_prompts_example.py new file mode 100644 index 0000000..79a25c0 --- /dev/null +++ b/examples/10_prompts_example.py @@ -0,0 +1,321 @@ +""" +Example demonstrating MCP Prompts support in FastAPI-MCP. + +This example shows: +1. How to create custom prompt templates for AI interactions +2. Three different approaches to controlling prompt auto-generation: + - Auto-generated prompts only (default) + - Custom prompts only (disabled auto-generation) + - Mixed approach (auto-generated + custom overrides) +3. API-related prompts including tool guidance and troubleshooting +4. How to override auto-generated prompts with enhanced custom versions +""" + +from typing import Optional +from fastapi import FastAPI + +from fastapi_mcp import FastApiMCP, PromptMessage, TextContent +from examples.shared.setup import setup_logging + +setup_logging() + +app = FastAPI( + title="Prompts Example API", description="An example API demonstrating MCP Prompts functionality", version="1.0.0" +) + +# Create MCP server with auto-generated prompts enabled (default: True) +# Set auto_generate_prompts=False to disable auto-generation and only use custom prompts +mcp = FastApiMCP(app, auto_generate_prompts=True) + +# ===== PROMPT CONTROL OPTIONS ===== +# This example demonstrates three different approaches to managing prompts: +# +# Option 1: Auto-generated prompts only (default behavior) +# mcp = FastApiMCP(app, auto_generate_prompts=True) # This is what we're using above +# +# Option 2: Custom prompts only (no auto-generation) +# mcp = FastApiMCP(app, auto_generate_prompts=False) +# # Then define only custom prompts with @mcp.prompt() +# +# Option 3: Mixed approach (auto-generated + custom overrides) +# mcp = FastApiMCP(app, auto_generate_prompts=True) # Auto-generate for all tools +# # Then override specific ones or add additional custom prompts + + +# Regular FastAPI endpoints (these will get auto-generated prompts) +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy"} + + +@app.get("/items") +async def list_items(skip: int = 0, limit: int = 10): + """List all items with pagination.""" + # This would return actual items in a real app + return [{"id": i, "name": f"Item {i}"} for i in range(skip, skip + limit)] + + +@app.post("/items") +async def create_item(name: str, description: str, price: float): + """Create a new item.""" + return {"id": 123, "name": name, "description": description, "price": price} + + +# Example 1: Basic welcome prompt +@mcp.prompt("welcome", title="Welcome Message", description="Generate a friendly welcome message") +def welcome_prompt(): + """Generate a welcome message for API users.""" + return PromptMessage( + role="user", + content=TextContent(text="Please provide a warm and friendly welcome message for new users of our API."), + ) + + +# Example 2: Custom tool prompt override (demonstrates Option 3: Mixed approach) +# This overrides the auto-generated prompt for the create_item tool +@mcp.prompt("use_create_item", title="Create Item Tool Guide", description="Custom guidance for creating items") +def create_item_guide(): + """Override the default auto-generated prompt for the create_item tool. + + This demonstrates how you can enhance auto-generated prompts with custom, + domain-specific guidance while keeping auto-generation for other tools. + """ + return PromptMessage( + role="user", + content=TextContent( + text="""Use the create_item tool to add new items to the inventory system. + +**Best Practices:** +1. Always provide a unique, descriptive name for new items +2. Include a clear, detailed description explaining the item's purpose +3. Set a reasonable price (must be greater than 0) +4. Consider the target audience when naming and describing items + +**Parameter Guidelines:** +- **name**: Use clear, concise naming (e.g., "Wireless Bluetooth Headphones") +- **description**: Be specific about features and use cases +- **price**: Use decimal format for currency (e.g., 29.99) + +**Common Issues to Avoid:** +- Vague or unclear item names +- Missing or incomplete descriptions +- Negative or zero prices +- Duplicate item names + +**Example:** +``` +name: "Professional Wireless Mouse" +description: "Ergonomic wireless mouse with precision tracking, suitable for office work and gaming" +price: 45.99 +``` + +This tool will create the item and return the generated item with its assigned details. + +**Note**: This is a custom override of the auto-generated prompt. The other endpoints +(health_check, list_items) will use their auto-generated prompts for guidance. + """ + ), + ) + + +# Example 3: API documentation prompt +@mcp.prompt( + "api_documentation", + title="API Documentation Helper", + description="Generate comprehensive API documentation prompts", +) +def api_docs_prompt(endpoint_path: Optional[str] = None): + """Generate prompts for API documentation help.""" + if endpoint_path: + return PromptMessage( + role="user", + content=TextContent( + text=f"""Please provide comprehensive documentation for the {endpoint_path} endpoint. + +Include the following details: +1. **Purpose**: What this endpoint does and when to use it +2. **HTTP Method**: GET, POST, PUT, DELETE, etc. +3. **Parameters**: All required and optional parameters with types and descriptions +4. **Request Examples**: Sample requests with proper formatting +5. **Response Format**: Expected response structure and data types +6. **Status Codes**: Possible HTTP status codes and their meanings +7. **Error Handling**: Common errors and how to resolve them +8. **Use Cases**: Practical examples of when to use this endpoint + +Make the documentation clear and actionable for developers. + """ + ), + ) + else: + # Generate dynamic content based on current API routes + routes_info = [] + for route in app.routes: + if hasattr(route, "methods") and hasattr(route, "path"): + # Filter out internal routes + if not route.path.startswith("/mcp") and route.path != "/docs" and route.path != "/openapi.json": + methods = ", ".join(m for m in route.methods if m != "HEAD") + routes_info.append(f"- {methods} {route.path}") + + return PromptMessage( + role="user", + content=TextContent( + text=f"""Help me understand this API and create comprehensive documentation. + +**Available API Endpoints:** +{chr(10).join(routes_info)} + +Please provide: +1. **API Overview**: Purpose and main functionality of this API +2. **Getting Started**: How to begin using the API +3. **Endpoint Guide**: Brief description of what each endpoint does +4. **Common Workflows**: Step-by-step guides for typical use cases +5. **Best Practices**: Recommendations for effective API usage +6. **Error Handling**: How to handle common errors and edge cases + +**Focus Areas:** +- Make it beginner-friendly but comprehensive +- Include practical examples +- Explain the relationships between different endpoints +- Provide guidance on proper usage patterns + +Note: This API also supports MCP (Model Context Protocol) prompts to help with tool usage. + """ + ), + ) + + +# Example 4: API troubleshooting prompt +@mcp.prompt("troubleshoot", title="API Troubleshooting Assistant", description="Help troubleshoot API issues") +async def troubleshoot_prompt(error_message: str, endpoint: Optional[str] = None, status_code: Optional[int] = None): + """Generate troubleshooting prompts based on error information.""" + context_parts = [f"**Error Message**: {error_message}"] + + if endpoint: + context_parts.append(f"**Endpoint**: {endpoint}") + if status_code: + context_parts.append(f"**Status Code**: {status_code}") + + context = "\n".join(context_parts) + + return PromptMessage( + role="user", + content=TextContent( + text=f"""I'm experiencing an issue with this API: + +{context} + +Please help me troubleshoot this issue: + +1. **Root Cause Analysis**: What might be causing this error? +2. **Immediate Steps**: What should I check first? +3. **Resolution**: How can I fix this specific issue? +4. **Prevention**: How can I avoid this error in the future? +5. **Alternative Approaches**: Are there other ways to achieve the same goal? + +**Additional Context to Consider:** +- Check if all required parameters are provided +- Verify parameter types and formats +- Ensure proper authentication if required +- Confirm the endpoint URL is correct +- Review any rate limiting or quota restrictions + +Please provide specific, actionable advice based on the error details above. + """ + ), + ) + + +# ===== DEMONSTRATION OF ALL THREE CONTROL OPTIONS ===== + + +def demonstrate_prompt_control_options(): + """ + This function demonstrates all three prompt control options. + Uncomment the sections below to try different approaches. + """ + + # ===== OPTION 1: Auto-generated prompts only ===== + # This is what we're using in the main example above + print("Current setup: Auto-generated prompts enabled (default)") + print("- Creates 'use_health_check', 'use_list_items', 'use_create_item' prompts automatically") + print("- 'use_create_item' is overridden with our custom version") + print("- Also includes custom prompts: 'welcome', 'api_documentation', 'troubleshoot'") + + """ + # ===== OPTION 2: Custom prompts only ===== + # Uncomment this section to try custom-only approach + + from fastapi import FastAPI + + app_custom = FastAPI(title="Custom Prompts Only API") + + @app_custom.get("/users") + def list_users(): return [{"id": 1, "name": "User 1"}] + + @app_custom.post("/users") + def create_user(name: str): return {"id": 2, "name": name} + + # Disable auto-generation + mcp_custom = FastApiMCP(app_custom, auto_generate_prompts=False) + + @mcp_custom.prompt("user_management_help") + def user_help(): + return PromptMessage( + role="user", + content=TextContent(text="Help me manage users effectively...") + ) + + @mcp_custom.prompt("create_user_workflow") + def user_workflow(): + return PromptMessage( + role="user", + content=TextContent(text="Guide me through user creation...") + ) + + mcp_custom.mount("/custom-mcp") + print("Custom-only setup would have only 2 prompts: user_management_help, create_user_workflow") + """ + + """ + # ===== OPTION 3: Mixed approach (what we're demonstrating above) ===== + # This is the approach used in our main example: + # - Auto-generate prompts for all tools (auto_generate_prompts=True) + # - Override specific auto-generated prompts (use_create_item) + # - Add additional custom prompts (welcome, api_documentation, troubleshoot) + + # This gives you: + # Auto-generated: use_health_check, use_list_items, use_create_item (overridden) + # Custom: welcome, api_documentation, troubleshoot, use_create_item (custom version) + """ + + +# Mount the MCP server (this will auto-generate prompts for all tools) +mcp.mount() + +# Print information about the current setup +if __name__ == "__main__": + print("\n" + "=" * 60) + print("FastAPI-MCP Prompts Example") + print("=" * 60) + + demonstrate_prompt_control_options() + + print(f"\nTotal prompts available: {len(mcp.prompt_registry.get_prompt_list())}") + print("\nPrompt names:") + for prompt in mcp.prompt_registry.get_prompt_list(): + prompt_type = "Auto-generated" if prompt.name.startswith("use_") else "Custom" + if prompt.name == "use_create_item": + prompt_type += " (Overridden)" + print(f" - {prompt.name} ({prompt_type})") + + print("\n" + "=" * 60) + print("Choose your preferred approach:") + print("1. Auto-generated only: FastApiMCP(app, auto_generate_prompts=True)") + print("2. Custom only: FastApiMCP(app, auto_generate_prompts=False)") + print("3. Mixed (current): Auto-generate + custom overrides/additions") + print("=" * 60) + + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/README.md b/examples/README.md index 494a946..02704c1 100644 --- a/examples/README.md +++ b/examples/README.md @@ -9,3 +9,6 @@ The following examples demonstrate various features and usage patterns of FastAP 5. [Reregister Tools](05_reregister_tools_example.py) - Adding endpoints after MCP server creation 6. [Custom MCP Router](06_custom_mcp_router_example.py) - Advanced routing configuration 7. [Configure HTTP Timeout](07_configure_http_timeout_example.py) - Customizing timeout settings +8. [Authentication Example - Token Passthrough](08_auth_example_token_passthrough.py) - Basic token authentication +9. [Authentication Example - Auth0](09_auth_example_auth0.py) - OAuth with Auth0 +10. [MCP Prompts Example](10_prompts_example.py) - Using MCP Prompts for structured AI interactions and controlling auto-generation diff --git a/fastapi_mcp/__init__.py b/fastapi_mcp/__init__.py index f748712..7f041cb 100644 --- a/fastapi_mcp/__init__.py +++ b/fastapi_mcp/__init__.py @@ -13,11 +13,17 @@ __version__ = "0.0.0.dev0" # pragma: no cover from .server import FastApiMCP -from .types import AuthConfig, OAuthMetadata +from .types import AuthConfig, OAuthMetadata, PromptMessage, TextContent, ImageContent, AudioContent +from .prompts import PromptRegistry __all__ = [ "FastApiMCP", "AuthConfig", "OAuthMetadata", + "PromptMessage", + "TextContent", + "ImageContent", + "AudioContent", + "PromptRegistry", ] diff --git a/fastapi_mcp/prompts.py b/fastapi_mcp/prompts.py new file mode 100644 index 0000000..edc9721 --- /dev/null +++ b/fastapi_mcp/prompts.py @@ -0,0 +1,277 @@ +""" +MCP Prompts support for FastAPI-MCP. + +This module provides FastAPI-native decorators and utilities for defining +MCP-compliant prompts that can be discovered and executed by MCP clients. +""" + +import logging +from typing import Callable, List, Dict, Any, Optional, get_type_hints, Union +from inspect import signature, Parameter, iscoroutinefunction +import mcp.types as types + +from .types import PromptMessage, PromptArgument, TextContent + +logger = logging.getLogger(__name__) + + +class PromptRegistry: + """Registry for managing MCP prompts in a FastAPI application.""" + + def __init__(self): + self.prompts: Dict[str, Dict[str, Any]] = {} + + def register_prompt( + self, name: str, title: Optional[str] = None, description: Optional[str] = None, func: Optional[Callable] = None + ): + """ + Register a prompt function with the registry. + + Args: + name: Unique identifier for the prompt + title: Human-readable title for the prompt + description: Description of what the prompt does + func: The prompt function to register + """ + + def decorator(func: Callable) -> Callable: + # Extract argument schema from function signature + sig = signature(func) + type_hints = get_type_hints(func) + + arguments = [] + properties = {} + required = [] + + # Process function parameters to create prompt arguments + for param_name, param in sig.parameters.items(): + if param_name in ["self", "cls"]: # Skip self/cls parameters + continue + + param_type = type_hints.get(param_name, str) + is_required = param.default == Parameter.empty + param_desc = f"Parameter {param_name}" + + # Try to extract description from docstring or annotations + if hasattr(param_type, "__doc__") and param_type.__doc__: + param_desc = param_type.__doc__ + + arguments.append(PromptArgument(name=param_name, description=param_desc, required=is_required)) + + # Create JSON schema property + properties[param_name] = self._type_to_json_schema(param_type) + if is_required: + required.append(param_name) + + # Store prompt definition + self.prompts[name] = { + "name": name, + "title": title or name.replace("_", " ").title(), + "description": description or func.__doc__ or f"Prompt: {name}", + "arguments": arguments, + "func": func, + "input_schema": { + "type": "object", + "properties": properties, + "required": required if required else None, + }, + } + + logger.debug(f"Registered prompt: {name}") + return func + + if func is None: + return decorator + else: + return decorator(func) + + def _type_to_json_schema(self, param_type: type) -> Dict[str, Any]: + """Convert Python type to JSON schema property.""" + if param_type is str: + return {"type": "string"} + elif param_type is int: + return {"type": "integer"} + elif param_type is float: + return {"type": "number"} + elif param_type is bool: + return {"type": "boolean"} + elif hasattr(param_type, "__origin__"): + # Handle generic types like List, Optional, etc. + if param_type.__origin__ is list: + return {"type": "array", "items": {"type": "string"}} + elif param_type.__origin__ is Union: + # Handle Optional types (Union[T, None]) + if hasattr(param_type, "__args__"): + args = param_type.__args__ + if len(args) == 2 and type(None) in args: + non_none_type = args[0] if args[1] is type(None) else args[1] + return self._type_to_json_schema(non_none_type) + + # Default to string for unknown types + return {"type": "string"} + + def get_prompt_list(self) -> List[types.Prompt]: + """Get list of all registered prompts in MCP format.""" + mcp_prompts = [] + + for prompt_def in self.prompts.values(): + # Convert our PromptArgument objects to MCP format + mcp_arguments = [] + for arg in prompt_def["arguments"]: + mcp_arguments.append( + types.PromptArgument(name=arg.name, description=arg.description, required=arg.required) + ) + + mcp_prompts.append( + types.Prompt(name=prompt_def["name"], description=prompt_def["description"], arguments=mcp_arguments) + ) + + return mcp_prompts + + async def get_prompt(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> List[types.PromptMessage]: + """ + Execute a prompt function and return the result. + + Args: + name: Name of the prompt to execute + arguments: Arguments to pass to the prompt function + + Returns: + List of prompt messages in MCP format + + Raises: + ValueError: If prompt is not found + """ + if name not in self.prompts: + raise ValueError(f"Prompt '{name}' not found") + + prompt_def = self.prompts[name] + func = prompt_def["func"] + args = arguments or {} + + try: + # Call the prompt function + if iscoroutinefunction(func): + result = await func(**args) + else: + result = func(**args) + + # Ensure result is a list + if not isinstance(result, list): + result = [result] + + # Convert our PromptMessage objects to MCP format + mcp_messages = [] + for msg in result: + if isinstance(msg, PromptMessage): + # Convert content to MCP format + mcp_content: Union[types.TextContent, types.ImageContent, types.EmbeddedResource] + if hasattr(msg.content, "type"): + if msg.content.type == "text": + mcp_content = types.TextContent(type="text", text=msg.content.text) + elif msg.content.type == "image": + mcp_content = types.ImageContent( + type="image", data=msg.content.data, mimeType=msg.content.mimeType + ) + elif msg.content.type == "audio": + # Note: mcp.types may not have AudioContent, so we'll use TextContent as fallback + mcp_content = types.TextContent( + type="text", text=f"[Audio content: {msg.content.mimeType}]" + ) + else: + mcp_content = types.TextContent(type="text", text=str(msg.content)) + else: + mcp_content = types.TextContent(type="text", text=str(msg.content)) + + mcp_messages.append(types.PromptMessage(role=msg.role, content=mcp_content)) + else: + # Handle string or other simple types + mcp_messages.append( + types.PromptMessage(role="user", content=types.TextContent(type="text", text=str(msg))) + ) + + return mcp_messages + + except Exception as e: + logger.error(f"Error executing prompt '{name}': {e}") + raise ValueError(f"Error executing prompt '{name}': {str(e)}") + + def has_prompts(self) -> bool: + """Check if any prompts are registered.""" + return len(self.prompts) > 0 + + def auto_register_tool_prompts(self, tools: List[types.Tool], operation_map: Dict[str, Dict[str, Any]]) -> None: + """ + Automatically register default prompts for each tool. + + Args: + tools: List of MCP tools to create prompts for + operation_map: Mapping of operation IDs to operation details + """ + for tool in tools: + prompt_name = f"use_{tool.name}" + + # Skip if user has already registered a custom prompt with this name + if prompt_name in self.prompts: + logger.debug(f"Skipping auto-registration for {prompt_name} - custom prompt exists") + continue + + # Generate prompt content for this tool + prompt_content = self._generate_tool_prompt_content(tool, operation_map.get(tool.name, {})) + + # Create a simple prompt function + def create_tool_prompt(content: str): + def tool_prompt_func(): + return PromptMessage(role="user", content=TextContent(type="text", text=content)) + + return tool_prompt_func + + # Register the auto-generated prompt + self.prompts[prompt_name] = { + "name": prompt_name, + "title": f"Usage Guide: {tool.name}", + "description": f"Best practices and guidance for using the {tool.name} tool effectively", + "arguments": [], + "func": create_tool_prompt(prompt_content), + "input_schema": {"type": "object", "properties": {}, "required": []}, + "auto_generated": True, + } + + logger.debug(f"Auto-registered prompt: {prompt_name}") + + def _generate_tool_prompt_content(self, tool: types.Tool, operation_info: Dict[str, Any]) -> str: + """ + Generate helpful prompt content for a tool. + + Args: + tool: The MCP tool to generate content for + operation_info: Operation details from the operation map + + Returns: + Generated prompt content as a string + """ + # Focus on actionable guidance rather than repeating tool information + content_parts = [ + f"You are about to use the **{tool.name}** tool.", + "", + "**Key Guidelines:**", + "• Review the tool's description and parameter requirements carefully", + "• Provide all required parameters with appropriate values", + "• Use relevant data that matches the user's actual needs and context", + "• Check the expected response format before interpreting results", + "", + "**Best Practices:**", + "• Validate your inputs match the expected parameter types", + "• Use values that make sense for the user's specific request", + "• Handle potential errors gracefully", + "• Consider the business logic and constraints of the operation", + "", + "**Execution Tips:**", + "• Double-check required vs optional parameters", + "• Use appropriate data formats (strings, numbers, booleans)", + "• Consider edge cases and boundary conditions", + "", + "💡 **Pro Tip**: The tool description and schema contain all technical details. Focus on using parameters that are relevant to the user's specific request and goals.", + ] + + return "\n".join(content_parts) diff --git a/fastapi_mcp/server.py b/fastapi_mcp/server.py index f5c4fc6..96edaf6 100644 --- a/fastapi_mcp/server.py +++ b/fastapi_mcp/server.py @@ -11,6 +11,7 @@ from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools from fastapi_mcp.transport.sse import FastApiSseTransport from fastapi_mcp.types import HTTPRequestInfo, AuthConfig +from fastapi_mcp.prompts import PromptRegistry import logging @@ -109,6 +110,10 @@ def __init__( Optional[List[str]], Doc("List of tags to exclude from MCP tools. Cannot be used with include_tags."), ] = None, + auto_generate_prompts: Annotated[ + bool, + Doc("Whether to automatically generate default prompts for each tool. Defaults to True."), + ] = True, auth_config: Annotated[ Optional[AuthConfig], Doc("Configuration for MCP authentication"), @@ -124,6 +129,7 @@ def __init__( self.operation_map: Dict[str, Dict[str, Any]] self.tools: List[types.Tool] self.server: Server + self.prompt_registry: PromptRegistry self.fastapi = fastapi self.name = name or self.fastapi.title or "FastAPI MCP" @@ -136,6 +142,7 @@ def __init__( self._exclude_operations = exclude_operations self._include_tags = include_tags self._exclude_tags = exclude_tags + self._auto_generate_prompts = auto_generate_prompts self._auth_config = auth_config if self._auth_config: @@ -147,6 +154,9 @@ def __init__( timeout=10.0, ) + # Initialize prompt registry + self.prompt_registry = PromptRegistry() + self.setup_server() def setup_server(self) -> None: @@ -167,6 +177,10 @@ def setup_server(self) -> None: # Filter tools based on operation IDs and tags self.tools = self._filter_tools(all_tools, openapi_schema) + # Auto-register default prompts for each tool if enabled + if self._auto_generate_prompts: + self.prompt_registry.auto_register_tool_prompts(self.tools, self.operation_map) + mcp_server: LowlevelMCPServer = LowlevelMCPServer(self.name, self.description) @mcp_server.list_tools() @@ -185,8 +199,34 @@ async def handle_call_tool( http_request_info=http_request_info, ) + # Add prompt handlers + @mcp_server.list_prompts() + async def handle_list_prompts() -> List[types.Prompt]: + return self.prompt_registry.get_prompt_list() + + @mcp_server.get_prompt() + async def handle_get_prompt(name: str, arguments: Optional[Dict[str, Any]] = None) -> types.GetPromptResult: + messages = await self.prompt_registry.get_prompt(name, arguments) + return types.GetPromptResult(description=f"Prompt: {name}", messages=messages) + self.server = mcp_server + def prompt(self, name: str, title: Optional[str] = None, description: Optional[str] = None): + """ + Decorator to register a prompt function. + + Args: + name: Unique identifier for the prompt + title: Human-readable title for the prompt + description: Description of what the prompt does + + Example: + @mcp.prompt("code_review", title="Code Review", description="Review code for issues") + async def code_review(code: str, language: str = "python"): + return [PromptMessage(role="user", content=TextContent(text=f"Review this {language} code: {code}"))] + """ + return self.prompt_registry.register_prompt(name, title, description) + def _register_mcp_connection_endpoint_sse( self, router: FastAPI | APIRouter, diff --git a/fastapi_mcp/types.py b/fastapi_mcp/types.py index 2e8cf2e..9815e8b 100644 --- a/fastapi_mcp/types.py +++ b/fastapi_mcp/types.py @@ -382,3 +382,61 @@ class ClientRegistrationResponse(BaseType): grant_types: List[str] token_endpoint_auth_method: str client_name: str + + +# MCP Prompts Support +class PromptArgument(BaseType): + """Argument definition for MCP prompts""" + + name: str + description: Optional[str] = None + required: bool = False + + +class TextContent(BaseType): + """Text content for prompt messages""" + + type: Literal["text"] = "text" + text: str + + +class ImageContent(BaseType): + """Image content for prompt messages""" + + type: Literal["image"] = "image" + data: str # base64 encoded + mimeType: str + + +class AudioContent(BaseType): + """Audio content for prompt messages""" + + type: Literal["audio"] = "audio" + data: str # base64 encoded + mimeType: str + + +class ResourceContent(BaseType): + """Resource content for prompt messages""" + + type: Literal["resource"] = "resource" + resource: Dict[str, Any] + + +PromptContent = Union[TextContent, ImageContent, AudioContent, ResourceContent] + + +class PromptMessage(BaseType): + """Message in a prompt conversation""" + + role: Literal["user", "assistant"] + content: PromptContent + + +class PromptDefinition(BaseType): + """Complete prompt definition""" + + name: str + title: Optional[str] = None + description: Optional[str] = None + arguments: List[PromptArgument] = [] diff --git a/tests/test_mcp_execute_api_tool.py b/tests/test_mcp_execute_api_tool.py index cc05d34..a492f65 100644 --- a/tests/test_mcp_execute_api_tool.py +++ b/tests/test_mcp_execute_api_tool.py @@ -10,183 +10,150 @@ async def test_execute_api_tool_success(simple_fastapi_app: FastAPI): """Test successful execution of an API tool.""" mcp = FastApiMCP(simple_fastapi_app) - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = {"id": 1, "name": "Test Item"} mock_response.status_code = 200 mock_response.text = '{"id": 1, "name": "Test Item"}' - + # Mock the HTTP client mock_client = AsyncMock() mock_client.get.return_value = mock_response - + # Test parameters tool_name = "get_item" arguments = {"item_id": 1} - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) assert result[0].text == '{\n "id": 1,\n "name": "Test Item"\n}' - + # Verify the HTTP client was called correctly - mock_client.get.assert_called_once_with( - "/items/1", - params={}, - headers={} - ) + mock_client.get.assert_called_once_with("/items/1", params={}, headers={}) @pytest.mark.asyncio async def test_execute_api_tool_with_query_params(simple_fastapi_app: FastAPI): """Test execution of an API tool with query parameters.""" mcp = FastApiMCP(simple_fastapi_app) - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}] mock_response.status_code = 200 mock_response.text = '[{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}]' - + # Mock the HTTP client mock_client = AsyncMock() mock_client.get.return_value = mock_response - + # Test parameters tool_name = "list_items" arguments = {"skip": 0, "limit": 2} - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) - + # Verify the HTTP client was called with query parameters - mock_client.get.assert_called_once_with( - "/items/", - params={"skip": 0, "limit": 2}, - headers={} - ) + mock_client.get.assert_called_once_with("/items/", params={"skip": 0, "limit": 2}, headers={}) @pytest.mark.asyncio async def test_execute_api_tool_with_body(simple_fastapi_app: FastAPI): """Test execution of an API tool with request body.""" mcp = FastApiMCP(simple_fastapi_app) - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = {"id": 1, "name": "New Item"} mock_response.status_code = 200 mock_response.text = '{"id": 1, "name": "New Item"}' - + # Mock the HTTP client mock_client = AsyncMock() mock_client.post.return_value = mock_response - + # Test parameters tool_name = "create_item" arguments = { - "item": { - "id": 1, - "name": "New Item", - "price": 10.0, - "tags": ["tag1"], - "description": "New item description" - } + "item": {"id": 1, "name": "New Item", "price": 10.0, "tags": ["tag1"], "description": "New item description"} } - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) - + # Verify the HTTP client was called with the request body - mock_client.post.assert_called_once_with( - "/items/", - params={}, - headers={}, - json=arguments - ) + mock_client.post.assert_called_once_with("/items/", params={}, headers={}, json=arguments) @pytest.mark.asyncio async def test_execute_api_tool_with_non_ascii_chars(simple_fastapi_app: FastAPI): """Test execution of an API tool with non-ASCII characters.""" mcp = FastApiMCP(simple_fastapi_app) - + # Test data with both ASCII and non-ASCII characters test_data = { "id": 1, "name": "你好 World", # Chinese characters + ASCII "price": 10.0, "tags": ["tag1", "标签2"], # Chinese characters in tags - "description": "这是一个测试描述" # All Chinese characters + "description": "这是一个测试描述", # All Chinese characters } - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = test_data mock_response.status_code = 200 - mock_response.text = '{"id": 1, "name": "你好 World", "price": 10.0, "tags": ["tag1", "标签2"], "description": "这是一个测试描述"}' - + mock_response.text = ( + '{"id": 1, "name": "你好 World", "price": 10.0, "tags": ["tag1", "标签2"], "description": "这是一个测试描述"}' + ) + # Mock the HTTP client mock_client = AsyncMock() mock_client.get.return_value = mock_response - + # Test parameters tool_name = "get_item" arguments = {"item_id": 1} - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) - + # Verify that the response contains both ASCII and non-ASCII characters response_text = result[0].text assert "你好" in response_text # Chinese characters preserved assert "World" in response_text # ASCII characters preserved assert "标签2" in response_text # Chinese characters in tags preserved assert "这是一个测试描述" in response_text # All Chinese description preserved - + # Verify the HTTP client was called correctly - mock_client.get.assert_called_once_with( - "/items/1", - params={}, - headers={} - ) + mock_client.get.assert_called_once_with("/items/1", params={}, headers={}) diff --git a/tests/test_prompts.py b/tests/test_prompts.py new file mode 100644 index 0000000..1e423c1 --- /dev/null +++ b/tests/test_prompts.py @@ -0,0 +1,411 @@ +""" +Tests for MCP Prompts functionality. +""" + +import pytest +from typing import List, Optional +from fastapi import FastAPI + +from fastapi_mcp import FastApiMCP, PromptMessage, TextContent, ImageContent +from fastapi_mcp.prompts import PromptRegistry + + +class TestPromptRegistry: + """Test the PromptRegistry class.""" + + def test_empty_registry(self): + """Test that empty registry works correctly.""" + registry = PromptRegistry() + assert not registry.has_prompts() + assert registry.get_prompt_list() == [] + + def test_register_simple_prompt(self): + """Test registering a simple prompt function.""" + registry = PromptRegistry() + + @registry.register_prompt("test_prompt", "Test Prompt", "A test prompt") + def simple_prompt(): + return PromptMessage(role="user", content=TextContent(text="Hello, world!")) + + assert registry.has_prompts() + prompts = registry.get_prompt_list() + assert len(prompts) == 1 + + prompt = prompts[0] + assert prompt.name == "test_prompt" + assert prompt.description == "A test prompt" + assert len(prompt.arguments) == 0 + + def test_register_prompt_with_parameters(self): + """Test registering a prompt with parameters.""" + registry = PromptRegistry() + + @registry.register_prompt("param_prompt", "Parameterized Prompt") + def param_prompt(message: str, count: int = 1): + return PromptMessage(role="user", content=TextContent(text=f"Message: {message}, Count: {count}")) + + prompts = registry.get_prompt_list() + assert len(prompts) == 1 + + prompt = prompts[0] + assert len(prompt.arguments) == 2 + + # Check required parameter + message_arg = next(arg for arg in prompt.arguments if arg.name == "message") + assert message_arg.required is True + + # Check optional parameter + count_arg = next(arg for arg in prompt.arguments if arg.name == "count") + assert count_arg.required is False + + @pytest.mark.asyncio + async def test_execute_prompt(self): + """Test executing a registered prompt.""" + registry = PromptRegistry() + + @registry.register_prompt("echo_prompt") + def echo_prompt(text: str): + return PromptMessage(role="user", content=TextContent(text=f"Echo: {text}")) + + messages = await registry.get_prompt("echo_prompt", {"text": "Hello"}) + assert len(messages) == 1 + + message = messages[0] + assert message.role == "user" + assert hasattr(message.content, "text") + assert "Echo: Hello" in message.content.text + + @pytest.mark.asyncio + async def test_execute_async_prompt(self): + """Test executing an async prompt function.""" + registry = PromptRegistry() + + @registry.register_prompt("async_prompt") + async def async_prompt(name: str): + return PromptMessage(role="user", content=TextContent(text=f"Hello, {name}!")) + + messages = await registry.get_prompt("async_prompt", {"name": "World"}) + assert len(messages) == 1 + assert "Hello, World!" in messages[0].content.text + + @pytest.mark.asyncio + async def test_execute_nonexistent_prompt(self): + """Test executing a prompt that doesn't exist.""" + registry = PromptRegistry() + + with pytest.raises(ValueError, match="Prompt 'missing' not found"): + await registry.get_prompt("missing") + + @pytest.mark.asyncio + async def test_prompt_returns_list(self): + """Test prompt that returns multiple messages.""" + registry = PromptRegistry() + + @registry.register_prompt("multi_prompt") + def multi_prompt(): + return [ + PromptMessage(role="user", content=TextContent(text="First message")), + PromptMessage(role="assistant", content=TextContent(text="Second message")), + ] + + messages = await registry.get_prompt("multi_prompt") + assert len(messages) == 2 + assert messages[0].role == "user" + assert messages[1].role == "assistant" + + +class TestFastAPIMCPPrompts: + """Test prompts integration with FastApiMCP.""" + + def test_fastapi_mcp_has_prompt_decorator(self): + """Test that FastApiMCP has a prompt decorator.""" + app = FastAPI() + mcp = FastApiMCP(app) + + assert hasattr(mcp, "prompt") + assert hasattr(mcp, "prompt_registry") + assert isinstance(mcp.prompt_registry, PromptRegistry) + + def test_fastapi_mcp_prompt_registration(self): + """Test registering prompts through FastApiMCP.""" + app = FastAPI() + mcp = FastApiMCP(app) + + @mcp.prompt("test_prompt", title="Test", description="Test prompt") + def test_prompt(input_text: str): + return PromptMessage(role="user", content=TextContent(text=f"Input: {input_text}")) + + assert mcp.prompt_registry.has_prompts() + prompts = mcp.prompt_registry.get_prompt_list() + assert len(prompts) == 1 + assert prompts[0].name == "test_prompt" + + @pytest.mark.asyncio + async def test_fastapi_mcp_prompt_execution(self): + """Test executing prompts through FastApiMCP.""" + app = FastAPI() + mcp = FastApiMCP(app) + + @mcp.prompt("greet", description="Greeting prompt") + def greet_prompt(name: str, greeting: str = "Hello"): + return PromptMessage(role="user", content=TextContent(text=f"{greeting}, {name}!")) + + messages = await mcp.prompt_registry.get_prompt("greet", {"name": "Alice", "greeting": "Hi"}) + + assert len(messages) == 1 + assert "Hi, Alice!" in messages[0].content.text + + +class TestPromptTypes: + """Test prompt-related type definitions.""" + + def test_text_content_creation(self): + """Test creating TextContent.""" + content = TextContent(text="Hello, world!") + assert content.type == "text" + assert content.text == "Hello, world!" + + def test_image_content_creation(self): + """Test creating ImageContent.""" + content = ImageContent(data="base64data", mimeType="image/png") + assert content.type == "image" + assert content.data == "base64data" + assert content.mimeType == "image/png" + + def test_prompt_message_creation(self): + """Test creating PromptMessage.""" + message = PromptMessage(role="user", content=TextContent(text="Test message")) + assert message.role == "user" + assert message.content.type == "text" + assert message.content.text == "Test message" + + +class TestPromptComplexScenarios: + """Test complex prompt scenarios.""" + + @pytest.mark.asyncio + async def test_prompt_with_complex_types(self): + """Test prompt with complex parameter types.""" + registry = PromptRegistry() + + @registry.register_prompt("complex_prompt") + def complex_prompt(items: List[str], count: Optional[int] = None, enabled: bool = True): + text = f"Items: {items}, Count: {count}, Enabled: {enabled}" + return PromptMessage(role="user", content=TextContent(text=text)) + + prompts = registry.get_prompt_list() + prompt = prompts[0] + + # Check that we have the right number of arguments + assert len(prompt.arguments) == 3 + + # Execute the prompt + messages = await registry.get_prompt("complex_prompt", {"items": ["a", "b", "c"], "count": 5, "enabled": False}) + + assert len(messages) == 1 + assert "Items: ['a', 'b', 'c']" in messages[0].content.text + + @pytest.mark.asyncio + async def test_prompt_error_handling(self): + """Test error handling in prompt execution.""" + registry = PromptRegistry() + + @registry.register_prompt("error_prompt") + def error_prompt(): + raise ValueError("Test error") + + with pytest.raises(ValueError, match="Error executing prompt 'error_prompt'"): + await registry.get_prompt("error_prompt") + + +class TestAutoGeneratedToolPrompts: + """Test auto-generation of tool prompts.""" + + def test_auto_register_tool_prompts(self): + """Test that tool prompts are auto-registered.""" + from fastapi import FastAPI + from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools + from fastapi.openapi.utils import get_openapi + + app = FastAPI() + + @app.get("/test") + async def test_endpoint(): + """Test endpoint.""" + return {"message": "test"} + + # Generate OpenAPI schema and convert to tools + openapi_schema = get_openapi(title=app.title, version=app.version, routes=app.routes) + tools, operation_map = convert_openapi_to_mcp_tools(openapi_schema) + + registry = PromptRegistry() + registry.auto_register_tool_prompts(tools, operation_map) + + # Check that auto-generated prompts exist + assert registry.has_prompts() + prompts = registry.get_prompt_list() + + # Should have one prompt for the test endpoint + tool_prompts = [p for p in prompts if p.name.startswith("use_")] + assert len(tool_prompts) >= 1 + + # Check the auto-generated prompt has correct content + use_test_prompt = tool_prompts[0] + assert "Best practices and guidance" in use_test_prompt.description + + @pytest.mark.asyncio + async def test_auto_generated_prompt_execution(self): + """Test executing an auto-generated prompt.""" + from fastapi import FastAPI + from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools + from fastapi.openapi.utils import get_openapi + + app = FastAPI() + + @app.post("/create_item") + async def create_item(name: str, price: float): + """Create a new item.""" + return {"name": name, "price": price} + + # Generate tools and auto-register prompts + openapi_schema = get_openapi(title=app.title, version=app.version, routes=app.routes) + tools, operation_map = convert_openapi_to_mcp_tools(openapi_schema) + + registry = PromptRegistry() + registry.auto_register_tool_prompts(tools, operation_map) + + # Find and execute the auto-generated prompt + prompts = registry.get_prompt_list() + tool_prompts = [p for p in prompts if p.name.startswith("use_")] + assert len(tool_prompts) >= 1 + + # Execute the first auto-generated prompt + prompt_name = tool_prompts[0].name + messages = await registry.get_prompt(prompt_name) + + assert len(messages) == 1 + message = messages[0] + assert message.role == "user" + assert "Key Guidelines" in message.content.text + assert "Best Practices" in message.content.text + + +class TestPromptAutoGenerationControl: + """Test controlling auto-generation of prompts.""" + + def test_auto_generate_prompts_disabled(self): + """Test that auto-generation can be disabled.""" + from fastapi import FastAPI + + app = FastAPI(title="Test App") + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Create MCP server with auto-generation disabled + mcp = FastApiMCP(app, auto_generate_prompts=False) + + # Should have no auto-generated prompts + prompts = mcp.prompt_registry.get_prompt_list() + assert len(prompts) == 0 + + # But should still be able to add custom prompts + @mcp.prompt("custom_prompt") + def custom(): + return PromptMessage(role="user", content=TextContent(text="Custom prompt")) + + prompts = mcp.prompt_registry.get_prompt_list() + assert len(prompts) == 1 + assert prompts[0].name == "custom_prompt" + + def test_auto_generate_prompts_enabled_by_default(self): + """Test that auto-generation is enabled by default.""" + from fastapi import FastAPI + + app = FastAPI(title="Test App") + + @app.get("/test") + def test_endpoint(): + return {"message": "test"} + + # Create MCP server with default settings + mcp = FastApiMCP(app) + + # Should have auto-generated prompts + prompts = mcp.prompt_registry.get_prompt_list() + assert len(prompts) > 0 + + # Check that the auto-generated prompt exists + prompt_names = [p.name for p in prompts] + assert any("use_test_endpoint" in name for name in prompt_names) + + def test_custom_prompt_overrides_auto_generated(self): + """Test that custom prompts can override auto-generated ones.""" + from fastapi import FastAPI + + app = FastAPI(title="Test App") + + @app.post("/items") + def create_item(name: str): + return {"name": name} + + # Create MCP server with auto-generation enabled + mcp = FastApiMCP(app, auto_generate_prompts=True) + + # Get the auto-generated prompt first + prompts = mcp.prompt_registry.get_prompt_list() + auto_prompt = next(p for p in prompts if "use_create_item" in p.name) + + # Override with custom prompt using the same name + @mcp.prompt(auto_prompt.name, title="Custom Override", description="Custom override description") + def custom_override(): + return PromptMessage(role="user", content=TextContent(text="Custom override content")) + + # Should still have the same number of prompts (override, not add) + prompts_after = mcp.prompt_registry.get_prompt_list() + assert len(prompts_after) == len(prompts) + + # But the prompt should now have the custom description + overridden_prompt = next(p for p in prompts_after if p.name == auto_prompt.name) + assert overridden_prompt.description == "Custom override description" + + def test_mixed_auto_and_custom_prompts(self): + """Test mixing auto-generated and custom prompts.""" + from fastapi import FastAPI + + app = FastAPI(title="Test App") + + @app.get("/users") + def list_users(): + return [{"id": 1, "name": "User 1"}] + + @app.post("/users") + def create_user(name: str): + return {"id": 2, "name": name} + + # Create MCP server with auto-generation enabled + mcp = FastApiMCP(app, auto_generate_prompts=True) + + # Should have auto-generated prompts for both endpoints + initial_prompts = mcp.prompt_registry.get_prompt_list() + auto_prompt_count = len([p for p in initial_prompts if p.name.startswith("use_")]) + assert auto_prompt_count >= 2 # At least one for each endpoint + + # Add custom prompts + @mcp.prompt("user_guide", title="User Management Guide") + def user_guide(): + return PromptMessage(role="user", content=TextContent(text="User management guidance")) + + @mcp.prompt("api_overview", title="API Overview") + def api_overview(): + return PromptMessage(role="user", content=TextContent(text="API overview")) + + # Should now have auto-generated + custom prompts + final_prompts = mcp.prompt_registry.get_prompt_list() + final_auto_count = len([p for p in final_prompts if p.name.startswith("use_")]) + final_custom_count = len([p for p in final_prompts if not p.name.startswith("use_")]) + + assert final_auto_count == auto_prompt_count # Same number of auto-generated + assert final_custom_count == 2 # Two custom prompts added + assert len(final_prompts) == auto_prompt_count + 2