diff --git a/api/api.py b/api/api.py index f3cb65a8..558bcda4 100644 --- a/api/api.py +++ b/api/api.py @@ -6,7 +6,19 @@ from typing import List, Optional, Dict, Any, Literal import json from datetime import datetime -from pydantic import BaseModel, Field +from api.schemas import ( + WikiPage, + ProcessedProjectEntry, + RepoInfo, + WikiStructureModel, + WikiCacheData, + WikiCacheRequest, + WikiExportRequest, + Model, + Provider, + ModelConfig, + AuthorizationConfig, +) import google.generativeai as genai import asyncio @@ -36,101 +48,6 @@ def get_adalflow_default_root_path(): return os.path.expanduser(os.path.join("~", ".adalflow")) -# --- Pydantic Models --- -class WikiPage(BaseModel): - """ - Model for a wiki page. - """ - id: str - title: str - content: str - filePaths: List[str] - importance: str # Should ideally be Literal['high', 'medium', 'low'] - relatedPages: List[str] - -class ProcessedProjectEntry(BaseModel): - id: str # Filename - owner: str - repo: str - name: str # owner/repo - repo_type: str # Renamed from type to repo_type for clarity with existing models - submittedAt: int # Timestamp - language: str # Extracted from filename - -class RepoInfo(BaseModel): - owner: str - repo: str - type: str - token: Optional[str] = None - localPath: Optional[str] = None - repoUrl: Optional[str] = None - - -class WikiStructureModel(BaseModel): - """ - Model for the overall wiki structure. - """ - id: str - title: str - description: str - pages: List[WikiPage] - -class WikiCacheData(BaseModel): - """ - Model for the data to be stored in the wiki cache. - """ - wiki_structure: WikiStructureModel - generated_pages: Dict[str, WikiPage] - repo_url: Optional[str] = None #compatible for old cache - repo: Optional[RepoInfo] = None - provider: Optional[str] = None - model: Optional[str] = None - -class WikiCacheRequest(BaseModel): - """ - Model for the request body when saving wiki cache. - """ - repo: RepoInfo - language: str - wiki_structure: WikiStructureModel - generated_pages: Dict[str, WikiPage] - provider: str - model: str - -class WikiExportRequest(BaseModel): - """ - Model for requesting a wiki export. - """ - repo_url: str = Field(..., description="URL of the repository") - pages: List[WikiPage] = Field(..., description="List of wiki pages to export") - format: Literal["markdown", "json"] = Field(..., description="Export format (markdown or json)") - -# --- Model Configuration Models --- -class Model(BaseModel): - """ - Model for LLM model configuration - """ - id: str = Field(..., description="Model identifier") - name: str = Field(..., description="Display name for the model") - -class Provider(BaseModel): - """ - Model for LLM provider configuration - """ - id: str = Field(..., description="Provider identifier") - name: str = Field(..., description="Display name for the provider") - models: List[Model] = Field(..., description="List of available models for this provider") - supportsCustomModel: Optional[bool] = Field(False, description="Whether this provider supports custom models") - -class ModelConfig(BaseModel): - """ - Model for the entire model configuration - """ - providers: List[Provider] = Field(..., description="List of available model providers") - defaultProvider: str = Field(..., description="ID of the default provider") - -class AuthorizationConfig(BaseModel): - code: str = Field(..., description="Authorization code") from api.config import configs, WIKI_AUTH_MODE, WIKI_AUTH_CODE diff --git a/api/schemas.py b/api/schemas.py new file mode 100644 index 00000000..882036ee --- /dev/null +++ b/api/schemas.py @@ -0,0 +1,122 @@ +from typing import List, Optional, Dict, Any, Literal +from pydantic import BaseModel, Field + +# Wiki related models +class WikiPage(BaseModel): + """ + Model for a wiki page. + """ + id: str + title: str + content: str + filePaths: List[str] + importance: str # Should ideally be Literal['high', 'medium', 'low'] + relatedPages: List[str] + +class ProcessedProjectEntry(BaseModel): + id: str # Filename + owner: str + repo: str + name: str # owner/repo + repo_type: str # Renamed from type to repo_type for clarity with existing models + submittedAt: int # Timestamp + language: str # Extracted from filename + +class RepoInfo(BaseModel): + owner: str + repo: str + type: str + token: Optional[str] = None + localPath: Optional[str] = None + repoUrl: Optional[str] = None + +class WikiStructureModel(BaseModel): + """ + Model for the overall wiki structure. + """ + id: str + title: str + description: str + pages: List[WikiPage] + +class WikiCacheData(BaseModel): + """ + Model for the data to be stored in the wiki cache. + """ + wiki_structure: WikiStructureModel + generated_pages: Dict[str, WikiPage] + repo_url: Optional[str] = None # compatible for old cache + repo: Optional[RepoInfo] = None + provider: Optional[str] = None + model: Optional[str] = None + +class WikiCacheRequest(BaseModel): + """ + Model for the request body when saving wiki cache. + """ + repo: RepoInfo + language: str + wiki_structure: WikiStructureModel + generated_pages: Dict[str, WikiPage] + provider: str + model: str + +class WikiExportRequest(BaseModel): + """ + Model for requesting a wiki export. + """ + repo_url: str = Field(..., description="URL of the repository") + pages: List[WikiPage] = Field(..., description="List of wiki pages to export") + format: Literal["markdown", "json"] = Field(..., description="Export format (markdown or json)") + +# Model configuration related models +class Model(BaseModel): + """ + Model for LLM model configuration + """ + id: str = Field(..., description="Model identifier") + name: str = Field(..., description="Display name for the model") + +class Provider(BaseModel): + """ + Model for LLM provider configuration + """ + id: str = Field(..., description="Provider identifier") + name: str = Field(..., description="Display name for the provider") + models: List[Model] = Field(..., description="List of available models for this provider") + supportsCustomModel: Optional[bool] = Field(False, description="Whether this provider supports custom models") + +class ModelConfig(BaseModel): + """ + Model for the entire model configuration + """ + providers: List[Provider] = Field(..., description="List of available model providers") + defaultProvider: str = Field(..., description="ID of the default provider") + +class AuthorizationConfig(BaseModel): + code: str = Field(..., description="Authorization code") + +# Chat API models +class ChatMessage(BaseModel): + role: str # 'user' or 'assistant' + content: str + +class ChatCompletionRequest(BaseModel): + """ + Model for requesting a chat completion. + """ + repo_url: str = Field(..., description="URL of the repository to query") + messages: List[ChatMessage] = Field(..., description="List of chat messages") + filePath: Optional[str] = Field(None, description="Optional path to a file in the repository to include in the prompt") + token: Optional[str] = Field(None, description="Personal access token for private repositories") + type: Optional[str] = Field("github", description="Type of repository (e.g., 'github', 'gitlab', 'bitbucket')") + + # model parameters + provider: str = Field("google", description="Model provider (google, openai, openrouter, ollama, bedrock, azure)") + model: Optional[str] = Field(None, description="Model name for the specified provider") + + language: Optional[str] = Field("en", description="Language for content generation (e.g., 'en', 'ja', 'zh', 'es', 'kr', 'vi')") + excluded_dirs: Optional[str] = Field(None, description="Comma-separated list of directories to exclude from processing") + excluded_files: Optional[str] = Field(None, description="Comma-separated list of file patterns to exclude from processing") + included_dirs: Optional[str] = Field(None, description="Comma-separated list of directories to include exclusively") + included_files: Optional[str] = Field(None, description="Comma-separated list of file patterns to include exclusively") diff --git a/api/simple_chat.py b/api/simple_chat.py index 8fa160f1..1b73d25f 100644 --- a/api/simple_chat.py +++ b/api/simple_chat.py @@ -9,7 +9,7 @@ from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse -from pydantic import BaseModel, Field +from api.schemas import ChatMessage, ChatCompletionRequest from api.config import get_model_config, configs, OPENROUTER_API_KEY, OPENAI_API_KEY, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY from api.data_pipeline import count_tokens, get_file_content @@ -41,30 +41,6 @@ allow_headers=["*"], # Allows all headers ) -# Models for the API -class ChatMessage(BaseModel): - role: str # 'user' or 'assistant' - content: str - -class ChatCompletionRequest(BaseModel): - """ - Model for requesting a chat completion. - """ - repo_url: str = Field(..., description="URL of the repository to query") - messages: List[ChatMessage] = Field(..., description="List of chat messages") - filePath: Optional[str] = Field(None, description="Optional path to a file in the repository to include in the prompt") - token: Optional[str] = Field(None, description="Personal access token for private repositories") - type: Optional[str] = Field("github", description="Type of repository (e.g., 'github', 'gitlab', 'bitbucket')") - - # model parameters - provider: str = Field("google", description="Model provider (google, openai, openrouter, ollama, bedrock, azure)") - model: Optional[str] = Field(None, description="Model name for the specified provider") - - language: Optional[str] = Field("en", description="Language for content generation (e.g., 'en', 'ja', 'zh', 'es', 'kr', 'vi')") - excluded_dirs: Optional[str] = Field(None, description="Comma-separated list of directories to exclude from processing") - excluded_files: Optional[str] = Field(None, description="Comma-separated list of file patterns to exclude from processing") - included_dirs: Optional[str] = Field(None, description="Comma-separated list of directories to include exclusively") - included_files: Optional[str] = Field(None, description="Comma-separated list of file patterns to include exclusively") @app.post("/chat/completions/stream") async def chat_completions_stream(request: ChatCompletionRequest): diff --git a/api/websocket_wiki.py b/api/websocket_wiki.py index 577bfeeb..85a2e84a 100644 --- a/api/websocket_wiki.py +++ b/api/websocket_wiki.py @@ -7,7 +7,7 @@ from adalflow.components.model_client.ollama_client import OllamaClient from adalflow.core.types import ModelType from fastapi import WebSocket, WebSocketDisconnect, HTTPException -from pydantic import BaseModel, Field +from api.schemas import ChatMessage, ChatCompletionRequest from api.config import get_model_config, configs, OPENROUTER_API_KEY, OPENAI_API_KEY from api.data_pipeline import count_tokens, get_file_content @@ -23,30 +23,6 @@ logger = logging.getLogger(__name__) -# Models for the API -class ChatMessage(BaseModel): - role: str # 'user' or 'assistant' - content: str - -class ChatCompletionRequest(BaseModel): - """ - Model for requesting a chat completion. - """ - repo_url: str = Field(..., description="URL of the repository to query") - messages: List[ChatMessage] = Field(..., description="List of chat messages") - filePath: Optional[str] = Field(None, description="Optional path to a file in the repository to include in the prompt") - token: Optional[str] = Field(None, description="Personal access token for private repositories") - type: Optional[str] = Field("github", description="Type of repository (e.g., 'github', 'gitlab', 'bitbucket')") - - # model parameters - provider: str = Field("google", description="Model provider (google, openai, openrouter, ollama, azure)") - model: Optional[str] = Field(None, description="Model name for the specified provider") - - language: Optional[str] = Field("en", description="Language for content generation (e.g., 'en', 'ja', 'zh', 'es', 'kr', 'vi')") - excluded_dirs: Optional[str] = Field(None, description="Comma-separated list of directories to exclude from processing") - excluded_files: Optional[str] = Field(None, description="Comma-separated list of file patterns to exclude from processing") - included_dirs: Optional[str] = Field(None, description="Comma-separated list of directories to include exclusively") - included_files: Optional[str] = Field(None, description="Comma-separated list of file patterns to include exclusively") async def handle_websocket_chat(websocket: WebSocket): """