diff --git a/README.md b/README.md
index 4f145a1b..26d9e6dd 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@
 - **Ask Feature**: Chat with your repository using RAG-powered AI to get accurate answers
 - **DeepResearch**: Multi-turn research process that thoroughly investigates complex topics
 - **Multiple Model Providers**: Support for Google Gemini, OpenAI, OpenRouter, and local Ollama models
+- **Flexible Embeddings**: Choose between OpenAI, Google AI, or local Ollama embeddings for optimal performance
 
 ## 🚀 Quick Start (Super Easy!)
 
@@ -39,6 +40,8 @@ cd deepwiki-open
 # Create a .env file with your API keys
 echo "GOOGLE_API_KEY=your_google_api_key" > .env
 echo "OPENAI_API_KEY=your_openai_api_key" >> .env
+# Optional: Use Google AI embeddings instead of OpenAI (recommended if using Google models)
+echo "DEEPWIKI_EMBEDDER_TYPE=google" >> .env
 # Optional: Add OpenRouter API key if you want to use OpenRouter models
 echo "OPENROUTER_API_KEY=your_openrouter_api_key" >> .env
 # Optional: Add Ollama host if not local. defaults to http://localhost:11434
@@ -67,6 +70,8 @@ Create a `.env` file in the project root with these keys:
 ```
 GOOGLE_API_KEY=your_google_api_key
 OPENAI_API_KEY=your_openai_api_key
+# Optional: Use Google AI embeddings (recommended if using Google models)
+DEEPWIKI_EMBEDDER_TYPE=google
 # Optional: Add this if you want to use OpenRouter models
 OPENROUTER_API_KEY=your_openrouter_api_key
 # Optional: Add this if you want to use Azure OpenAI models
@@ -269,6 +274,89 @@ If you want to use embedding models compatible with the OpenAI API (such as Alib
 
 This allows you to seamlessly switch to any OpenAI-compatible embedding service without code changes.
 
+## 🧠 Using Google AI Embeddings
+
+DeepWiki now supports Google AI's latest embedding models as an alternative to OpenAI embeddings. This provides better integration when you're already using Google Gemini models for text generation.
+
+### Features
+
+- **Latest Model**: Uses Google's `text-embedding-004` model
+- **Same API Key**: Uses your existing `GOOGLE_API_KEY` (no additional setup required)
+- **Better Integration**: Optimized for use with Google Gemini text generation models
+- **Task-Specific**: Supports semantic similarity, retrieval, and classification tasks
+- **Batch Processing**: Efficient processing of multiple texts
+
+### How to Enable Google AI Embeddings
+
+**Option 1: Environment Variable (Recommended)**
+
+Set the embedder type in your `.env` file:
+
+```bash
+# Your existing Google API key
+GOOGLE_API_KEY=your_google_api_key
+
+# Enable Google AI embeddings
+DEEPWIKI_EMBEDDER_TYPE=google
+```
+
+**Option 2: Docker Environment**

+```bash
+docker run -p 8001:8001 -p 3000:3000 \
+  -e GOOGLE_API_KEY=your_google_api_key \
+  -e DEEPWIKI_EMBEDDER_TYPE=google \
+  -v ~/.adalflow:/root/.adalflow \
+  ghcr.io/asyncfuncai/deepwiki-open:latest
+```
+
+**Option 3: Docker Compose**
+
+Add to your `.env` file:
+
+```bash
+GOOGLE_API_KEY=your_google_api_key
+DEEPWIKI_EMBEDDER_TYPE=google
+```
+
+Then run:
+
+```bash
+docker-compose up
+```
+
+### Available Embedder Types
+
+| Type | Description | API Key Required | Notes |
+|------|-------------|------------------|-------|
+| `openai` | OpenAI embeddings (default) | `OPENAI_API_KEY` | Uses `text-embedding-3-small` model |
+| `google` | Google AI embeddings | `GOOGLE_API_KEY` | Uses `text-embedding-004` model |
+| `ollama` | Local Ollama embeddings | None | Requires local Ollama installation |
+
+### Why Use Google AI Embeddings?
+ +- **Consistency**: If you're using Google Gemini for text generation, using Google embeddings provides better semantic consistency +- **Performance**: Google's latest embedding model offers excellent performance for retrieval tasks +- **Cost**: Competitive pricing compared to OpenAI +- **No Additional Setup**: Uses the same API key as your text generation models + +### Switching Between Embedders + +You can easily switch between different embedding providers: + +```bash +# Use OpenAI embeddings (default) +export DEEPWIKI_EMBEDDER_TYPE=openai + +# Use Google AI embeddings +export DEEPWIKI_EMBEDDER_TYPE=google + +# Use local Ollama embeddings +export DEEPWIKI_EMBEDDER_TYPE=ollama +``` + +**Note**: When switching embedders, you may need to regenerate your repository embeddings as different models produce different vector spaces. + ### Logging DeepWiki uses Python's built-in `logging` module for diagnostic output. You can configure the verbosity and log file destination via environment variables: @@ -311,19 +399,25 @@ docker-compose up | Variable | Description | Required | Note | |----------------------|--------------------------------------------------------------|----------|----------------------------------------------------------------------------------------------------------| -| `GOOGLE_API_KEY` | Google Gemini API key for AI generation | No | Required only if you want to use Google Gemini models -| `OPENAI_API_KEY` | OpenAI API key for embeddings | Yes | Note: This is required even if you're not using OpenAI models, as it's used for embeddings. | +| `GOOGLE_API_KEY` | Google Gemini API key for AI generation and embeddings | No | Required for Google Gemini models and Google AI embeddings +| `OPENAI_API_KEY` | OpenAI API key for embeddings and models | Conditional | Required if using OpenAI embeddings or models | | `OPENROUTER_API_KEY` | OpenRouter API key for alternative models | No | Required only if you want to use OpenRouter models | | `AZURE_OPENAI_API_KEY` | Azure OpenAI API key | No | Required only if you want to use Azure OpenAI models | | `AZURE_OPENAI_ENDPOINT` | Azure OpenAI endpoint | No | Required only if you want to use Azure OpenAI models | | `AZURE_OPENAI_VERSION` | Azure OpenAI version | No | Required only if you want to use Azure OpenAI models | | `OLLAMA_HOST` | Ollama Host (default: http://localhost:11434) | No | Required only if you want to use external Ollama server | +| `DEEPWIKI_EMBEDDER_TYPE` | Embedder type: `openai`, `google`, or `ollama` (default: `openai`) | No | Controls which embedding provider to use | | `PORT` | Port for the API server (default: 8001) | No | If you host API and frontend on the same machine, make sure change port of `SERVER_BASE_URL` accordingly | | `SERVER_BASE_URL` | Base URL for the API server (default: http://localhost:8001) | No | | `DEEPWIKI_AUTH_MODE` | Set to `true` or `1` to enable authorization mode. | No | Defaults to `false`. If enabled, `DEEPWIKI_AUTH_CODE` is required. | | `DEEPWIKI_AUTH_CODE` | The secret code required for wiki generation when `DEEPWIKI_AUTH_MODE` is enabled. | No | Only used if `DEEPWIKI_AUTH_MODE` is `true` or `1`. | -If you're not using ollama mode, you need to configure an OpenAI API key for embeddings. Other API keys are only required when configuring and using models from the corresponding providers. 
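+
+To double-check which provider the backend actually resolved at startup, a quick probe (illustrative; `get_embedder_type` is added by this change) is:
+
+```python
+from api.config import get_embedder_type
+
+print(get_embedder_type())  # 'openai', 'google', or 'ollama'
+```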
+**API Key Requirements:** +- If using `DEEPWIKI_EMBEDDER_TYPE=openai` (default): `OPENAI_API_KEY` is required +- If using `DEEPWIKI_EMBEDDER_TYPE=google`: `GOOGLE_API_KEY` is required +- If using `DEEPWIKI_EMBEDDER_TYPE=ollama`: No API key required (local processing) + +Other API keys are only required when configuring and using models from the corresponding providers. ## Authorization Mode diff --git a/api/api.py b/api/api.py index f3cb65a8..a1433bda 100644 --- a/api/api.py +++ b/api/api.py @@ -507,7 +507,7 @@ async def delete_wiki_cache( if WIKI_AUTH_MODE: logger.info("check the authorization code") - if WIKI_AUTH_CODE != authorization_code: + if not authorization_code or WIKI_AUTH_CODE != authorization_code: raise HTTPException(status_code=401, detail="Authorization code is invalid") logger.info(f"Attempting to delete wiki cache for {owner}/{repo} ({repo_type}), lang: {language}") diff --git a/api/config.py b/api/config.py index 61eba243..f30b53a6 100644 --- a/api/config.py +++ b/api/config.py @@ -10,6 +10,7 @@ from api.openai_client import OpenAIClient from api.openrouter_client import OpenRouterClient from api.bedrock_client import BedrockClient +from api.google_embedder_client import GoogleEmbedderClient from api.azureai_client import AzureAIClient from adalflow import GoogleGenAIClient, OllamaClient @@ -43,12 +44,16 @@ WIKI_AUTH_MODE = raw_auth_mode.lower() in ['true', '1', 't'] WIKI_AUTH_CODE = os.environ.get('DEEPWIKI_AUTH_CODE', '') +# Embedder settings +EMBEDDER_TYPE = os.environ.get('DEEPWIKI_EMBEDDER_TYPE', 'openai').lower() + # Get configuration directory from environment variable, or use default if not set CONFIG_DIR = os.environ.get('DEEPWIKI_CONFIG_DIR', None) # Client class mapping CLIENT_CLASSES = { "GoogleGenAIClient": GoogleGenAIClient, + "GoogleEmbedderClient": GoogleEmbedderClient, "OpenAIClient": OpenAIClient, "OpenRouterClient": OpenRouterClient, "OllamaClient": OllamaClient, @@ -141,7 +146,7 @@ def load_embedder_config(): embedder_config = load_json_config("embedder.json") # Process client classes - for key in ["embedder", "embedder_ollama"]: + for key in ["embedder", "embedder_ollama", "embedder_google"]: if key in embedder_config and "client_class" in embedder_config[key]: class_name = embedder_config[key]["client_class"] if class_name in CLIENT_CLASSES: @@ -151,12 +156,18 @@ def load_embedder_config(): def get_embedder_config(): """ - Get the current embedder configuration. + Get the current embedder configuration based on DEEPWIKI_EMBEDDER_TYPE. Returns: dict: The embedder configuration with model_client resolved """ - return configs.get("embedder", {}) + embedder_type = EMBEDDER_TYPE + if embedder_type == 'google' and 'embedder_google' in configs: + return configs.get("embedder_google", {}) + elif embedder_type == 'ollama' and 'embedder_ollama' in configs: + return configs.get("embedder_ollama", {}) + else: + return configs.get("embedder", {}) def is_ollama_embedder(): """ @@ -178,6 +189,40 @@ def is_ollama_embedder(): client_class = embedder_config.get("client_class", "") return client_class == "OllamaClient" +def is_google_embedder(): + """ + Check if the current embedder configuration uses GoogleEmbedderClient. 
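+
+    Example (illustrative; EMBEDDER_TYPE is read once at import, so reload
+    api.config after changing the environment variable):
+        >>> os.environ["DEEPWIKI_EMBEDDER_TYPE"] = "google"
+        >>> importlib.reload(api.config)
+        >>> api.config.is_google_embedder()
+        True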
+ + Returns: + bool: True if using GoogleEmbedderClient, False otherwise + """ + embedder_config = get_embedder_config() + if not embedder_config: + return False + + # Check if model_client is GoogleEmbedderClient + model_client = embedder_config.get("model_client") + if model_client: + return model_client.__name__ == "GoogleEmbedderClient" + + # Fallback: check client_class string + client_class = embedder_config.get("client_class", "") + return client_class == "GoogleEmbedderClient" + +def get_embedder_type(): + """ + Get the current embedder type based on configuration. + + Returns: + str: 'ollama', 'google', or 'openai' (default) + """ + if is_ollama_embedder(): + return 'ollama' + elif is_google_embedder(): + return 'google' + else: + return 'openai' + # Load repository and file filters configuration def load_repo_config(): return load_json_config("repo.json") @@ -265,7 +310,7 @@ def load_lang_config(): # Update embedder configuration if embedder_config: - for key in ["embedder", "embedder_ollama", "retriever", "text_splitter"]: + for key in ["embedder", "embedder_ollama", "embedder_google", "retriever", "text_splitter"]: if key in embedder_config: configs[key] = embedder_config[key] diff --git a/api/config/embedder.json b/api/config/embedder.json index a70bdb41..f0ab52d1 100644 --- a/api/config/embedder.json +++ b/api/config/embedder.json @@ -8,6 +8,20 @@ "encoding_format": "float" } }, + "embedder_ollama": { + "client_class": "OllamaClient", + "model_kwargs": { + "model": "nomic-embed-text" + } + }, + "embedder_google": { + "client_class": "GoogleEmbedderClient", + "batch_size": 100, + "model_kwargs": { + "model": "text-embedding-004", + "task_type": "SEMANTIC_SIMILARITY" + } + }, "retriever": { "top_k": 20 }, diff --git a/api/data_pipeline.py b/api/data_pipeline.py index e8634537..f90a176d 100644 --- a/api/data_pipeline.py +++ b/api/data_pipeline.py @@ -25,27 +25,39 @@ # Maximum token limit for OpenAI embedding models MAX_EMBEDDING_TOKENS = 8192 -def count_tokens(text: str, is_ollama_embedder: bool = None) -> int: +def count_tokens(text: str, embedder_type: str = None, is_ollama_embedder: bool = None) -> int: """ Count the number of tokens in a text string using tiktoken. Args: text (str): The text to count tokens for. - is_ollama_embedder (bool, optional): Whether using Ollama embeddings. + embedder_type (str, optional): The embedder type ('openai', 'google', 'ollama'). + If None, will be determined from configuration. + is_ollama_embedder (bool, optional): DEPRECATED. Use embedder_type instead. If None, will be determined from configuration. Returns: int: The number of tokens in the text. 
""" try: - # Determine if using Ollama embedder if not specified - if is_ollama_embedder is None: - from api.config import is_ollama_embedder as check_ollama - is_ollama_embedder = check_ollama() - - if is_ollama_embedder: + # Handle backward compatibility + if embedder_type is None and is_ollama_embedder is not None: + embedder_type = 'ollama' if is_ollama_embedder else None + + # Determine embedder type if not specified + if embedder_type is None: + from api.config import get_embedder_type + embedder_type = get_embedder_type() + + # Choose encoding based on embedder type + if embedder_type == 'ollama': + # Ollama typically uses cl100k_base encoding encoding = tiktoken.get_encoding("cl100k_base") - else: + elif embedder_type == 'google': + # Google uses similar tokenization to GPT models for rough estimation + encoding = tiktoken.get_encoding("cl100k_base") + else: # OpenAI or default + # Use OpenAI embedding model encoding encoding = tiktoken.encoding_for_model("text-embedding-3-small") return len(encoding.encode(text)) @@ -127,14 +139,17 @@ def download_repo(repo_url: str, local_path: str, type: str = "github", access_t # Alias for backward compatibility download_github_repo = download_repo -def read_all_documents(path: str, is_ollama_embedder: bool = None, excluded_dirs: List[str] = None, excluded_files: List[str] = None, +def read_all_documents(path: str, embedder_type: str = None, is_ollama_embedder: bool = None, + excluded_dirs: List[str] = None, excluded_files: List[str] = None, included_dirs: List[str] = None, included_files: List[str] = None): """ Recursively reads all documents in a directory and its subdirectories. Args: path (str): The root directory path. - is_ollama_embedder (bool, optional): Whether using Ollama embeddings for token counting. + embedder_type (str, optional): The embedder type ('openai', 'google', 'ollama'). + If None, will be determined from configuration. + is_ollama_embedder (bool, optional): DEPRECATED. Use embedder_type instead. If None, will be determined from configuration. excluded_dirs (List[str], optional): List of directories to exclude from processing. Overrides the default configuration if provided. @@ -148,6 +163,9 @@ def read_all_documents(path: str, is_ollama_embedder: bool = None, excluded_dirs Returns: list: A list of Document objects with metadata. 
""" + # Handle backward compatibility + if embedder_type is None and is_ollama_embedder is not None: + embedder_type = 'ollama' if is_ollama_embedder else None documents = [] # File extensions to look for, prioritizing code files code_extensions = [".py", ".js", ".ts", ".java", ".cpp", ".c", ".h", ".hpp", ".go", ".rs", @@ -293,7 +311,7 @@ def should_process_file(file_path: str, use_inclusion: bool, included_dirs: List ) # Check token count - token_count = count_tokens(content, is_ollama_embedder) + token_count = count_tokens(content, embedder_type) if token_count > MAX_EMBEDDING_TOKENS * 10: logger.warning(f"Skipping large file {relative_path}: Token count ({token_count}) exceeds limit") continue @@ -327,7 +345,7 @@ def should_process_file(file_path: str, use_inclusion: bool, included_dirs: List relative_path = os.path.relpath(file_path, path) # Check token count - token_count = count_tokens(content, is_ollama_embedder) + token_count = count_tokens(content, embedder_type) if token_count > MAX_EMBEDDING_TOKENS: logger.warning(f"Skipping large file {relative_path}: Token count ({token_count}) exceeds limit") continue @@ -350,33 +368,40 @@ def should_process_file(file_path: str, use_inclusion: bool, included_dirs: List logger.info(f"Found {len(documents)} documents") return documents -def prepare_data_pipeline(is_ollama_embedder: bool = None): +def prepare_data_pipeline(embedder_type: str = None, is_ollama_embedder: bool = None): """ Creates and returns the data transformation pipeline. Args: - is_ollama_embedder (bool, optional): Whether to use Ollama for embedding. + embedder_type (str, optional): The embedder type ('openai', 'google', 'ollama'). + If None, will be determined from configuration. + is_ollama_embedder (bool, optional): DEPRECATED. Use embedder_type instead. If None, will be determined from configuration. Returns: adal.Sequential: The data transformation pipeline """ - from api.config import get_embedder_config, is_ollama_embedder as check_ollama + from api.config import get_embedder_config, get_embedder_type - # Determine if using Ollama embedder if not specified - if is_ollama_embedder is None: - is_ollama_embedder = check_ollama() + # Handle backward compatibility + if embedder_type is None and is_ollama_embedder is not None: + embedder_type = 'ollama' if is_ollama_embedder else None + + # Determine embedder type if not specified + if embedder_type is None: + embedder_type = get_embedder_type() splitter = TextSplitter(**configs["text_splitter"]) embedder_config = get_embedder_config() - embedder = get_embedder() + embedder = get_embedder(embedder_type=embedder_type) - if is_ollama_embedder: + # Choose appropriate processor based on embedder type + if embedder_type == 'ollama': # Use Ollama document processor for single-document processing embedder_transformer = OllamaDocumentProcessor(embedder=embedder) else: - # Use batch processing for other embedders + # Use batch processing for OpenAI and Google embedders batch_size = embedder_config.get("batch_size", 500) embedder_transformer = ToEmbeddings( embedder=embedder, batch_size=batch_size @@ -388,7 +413,7 @@ def prepare_data_pipeline(is_ollama_embedder: bool = None): return data_transformer def transform_documents_and_save_to_db( - documents: List[Document], db_path: str, is_ollama_embedder: bool = None + documents: List[Document], db_path: str, embedder_type: str = None, is_ollama_embedder: bool = None ) -> LocalDB: """ Transforms a list of documents and saves them to a local database. 
@@ -396,11 +421,13 @@ def transform_documents_and_save_to_db( Args: documents (list): A list of `Document` objects. db_path (str): The path to the local database file. - is_ollama_embedder (bool, optional): Whether to use Ollama for embedding. + embedder_type (str, optional): The embedder type ('openai', 'google', 'ollama'). + If None, will be determined from configuration. + is_ollama_embedder (bool, optional): DEPRECATED. Use embedder_type instead. If None, will be determined from configuration. """ # Get the data transformer - data_transformer = prepare_data_pipeline(is_ollama_embedder) + data_transformer = prepare_data_pipeline(embedder_type, is_ollama_embedder) # Save the documents to a local database db = LocalDB() @@ -631,7 +658,8 @@ def __init__(self): self.repo_url_or_path = None self.repo_paths = None - def prepare_database(self, repo_url_or_path: str, type: str = "github", access_token: str = None, is_ollama_embedder: bool = None, + def prepare_database(self, repo_url_or_path: str, type: str = "github", access_token: str = None, + embedder_type: str = None, is_ollama_embedder: bool = None, excluded_dirs: List[str] = None, excluded_files: List[str] = None, included_dirs: List[str] = None, included_files: List[str] = None) -> List[Document]: """ @@ -640,7 +668,9 @@ def prepare_database(self, repo_url_or_path: str, type: str = "github", access_t Args: repo_url_or_path (str): The URL or local path of the repository access_token (str, optional): Access token for private repositories - is_ollama_embedder (bool, optional): Whether to use Ollama for embedding. + embedder_type (str, optional): Embedder type to use ('openai', 'google', 'ollama'). + If None, will be determined from configuration. + is_ollama_embedder (bool, optional): DEPRECATED. Use embedder_type instead. If None, will be determined from configuration. excluded_dirs (List[str], optional): List of directories to exclude from processing excluded_files (List[str], optional): List of file patterns to exclude from processing @@ -650,9 +680,13 @@ def prepare_database(self, repo_url_or_path: str, type: str = "github", access_t Returns: List[Document]: List of Document objects """ + # Handle backward compatibility + if embedder_type is None and is_ollama_embedder is not None: + embedder_type = 'ollama' if is_ollama_embedder else None + self.reset_database() self._create_repo(repo_url_or_path, type, access_token) - return self.prepare_db_index(is_ollama_embedder=is_ollama_embedder, excluded_dirs=excluded_dirs, excluded_files=excluded_files, + return self.prepare_db_index(embedder_type=embedder_type, excluded_dirs=excluded_dirs, excluded_files=excluded_files, included_dirs=included_dirs, included_files=included_files) def reset_database(self): @@ -728,13 +762,16 @@ def _create_repo(self, repo_url_or_path: str, repo_type: str = "github", access_ logger.error(f"Failed to create repository structure: {e}") raise - def prepare_db_index(self, is_ollama_embedder: bool = None, excluded_dirs: List[str] = None, excluded_files: List[str] = None, + def prepare_db_index(self, embedder_type: str = None, is_ollama_embedder: bool = None, + excluded_dirs: List[str] = None, excluded_files: List[str] = None, included_dirs: List[str] = None, included_files: List[str] = None) -> List[Document]: """ Prepare the indexed database for the repository. Args: - is_ollama_embedder (bool, optional): Whether to use Ollama for embedding. + embedder_type (str, optional): Embedder type to use ('openai', 'google', 'ollama'). 
+ If None, will be determined from configuration. + is_ollama_embedder (bool, optional): DEPRECATED. Use embedder_type instead. If None, will be determined from configuration. excluded_dirs (List[str], optional): List of directories to exclude from processing excluded_files (List[str], optional): List of file patterns to exclude from processing @@ -744,6 +781,9 @@ def prepare_db_index(self, is_ollama_embedder: bool = None, excluded_dirs: List[ Returns: List[Document]: List of Document objects """ + # Handle backward compatibility + if embedder_type is None and is_ollama_embedder is not None: + embedder_type = 'ollama' if is_ollama_embedder else None # check the database if self.repo_paths and os.path.exists(self.repo_paths["save_db_file"]): logger.info("Loading existing database...") @@ -761,14 +801,14 @@ def prepare_db_index(self, is_ollama_embedder: bool = None, excluded_dirs: List[ logger.info("Creating new database...") documents = read_all_documents( self.repo_paths["save_repo_dir"], - is_ollama_embedder=is_ollama_embedder, + embedder_type=embedder_type, excluded_dirs=excluded_dirs, excluded_files=excluded_files, included_dirs=included_dirs, included_files=included_files ) self.db = transform_documents_and_save_to_db( - documents, self.repo_paths["save_db_file"], is_ollama_embedder=is_ollama_embedder + documents, self.repo_paths["save_db_file"], embedder_type=embedder_type ) logger.info(f"Total documents: {len(documents)}") transformed_docs = self.db.get_transformed_data(key="split_and_embed") diff --git a/api/google_embedder_client.py b/api/google_embedder_client.py new file mode 100644 index 00000000..b604fd8e --- /dev/null +++ b/api/google_embedder_client.py @@ -0,0 +1,231 @@ +"""Google AI Embeddings ModelClient integration.""" + +import os +import logging +import backoff +from typing import Dict, Any, Optional, List, Sequence + +from adalflow.core.model_client import ModelClient +from adalflow.core.types import ModelType, EmbedderOutput + +try: + import google.generativeai as genai + from google.generativeai.types.text_types import EmbeddingDict, BatchEmbeddingDict +except ImportError: + raise ImportError("google-generativeai is required. Install it with 'pip install google-generativeai'") + +log = logging.getLogger(__name__) + + +class GoogleEmbedderClient(ModelClient): + __doc__ = r"""A component wrapper for Google AI Embeddings API client. + + This client provides access to Google's embedding models through the Google AI API. + It supports text embeddings for various tasks including semantic similarity, + retrieval, and classification. + + Args: + api_key (Optional[str]): Google AI API key. Defaults to None. + If not provided, will use the GOOGLE_API_KEY environment variable. + env_api_key_name (str): Environment variable name for the API key. + Defaults to "GOOGLE_API_KEY". + + Example: + ```python + from api.google_embedder_client import GoogleEmbedderClient + import adalflow as adal + + client = GoogleEmbedderClient() + embedder = adal.Embedder( + model_client=client, + model_kwargs={ + "model": "text-embedding-004", + "task_type": "SEMANTIC_SIMILARITY" + } + ) + ``` + + References: + - Google AI Embeddings: https://ai.google.dev/gemini-api/docs/embeddings + - Available models: text-embedding-004, embedding-001 + """ + + def __init__( + self, + api_key: Optional[str] = None, + env_api_key_name: str = "GOOGLE_API_KEY", + ): + """Initialize Google AI Embeddings client. + + Args: + api_key: Google AI API key. If not provided, uses environment variable. 
+ env_api_key_name: Name of environment variable containing API key. + """ + super().__init__() + self._api_key = api_key + self._env_api_key_name = env_api_key_name + self._initialize_client() + + def _initialize_client(self): + """Initialize the Google AI client with API key.""" + api_key = self._api_key or os.getenv(self._env_api_key_name) + if not api_key: + raise ValueError( + f"Environment variable {self._env_api_key_name} must be set" + ) + genai.configure(api_key=api_key) + + def parse_embedding_response(self, response) -> EmbedderOutput: + """Parse Google AI embedding response to EmbedderOutput format. + + Args: + response: Google AI embedding response (EmbeddingDict or BatchEmbeddingDict) + + Returns: + EmbedderOutput with parsed embeddings + """ + try: + from adalflow.core.types import Embedding + + embedding_data = [] + + if isinstance(response, dict): + if 'embedding' in response: + embedding_value = response['embedding'] + if isinstance(embedding_value, list) and len(embedding_value) > 0: + # Check if it's a single embedding (list of floats) or batch (list of lists) + if isinstance(embedding_value[0], (int, float)): + # Single embedding response: {'embedding': [float, ...]} + embedding_data = [Embedding(embedding=embedding_value, index=0)] + else: + # Batch embeddings response: {'embedding': [[float, ...], [float, ...], ...]} + embedding_data = [ + Embedding(embedding=emb_list, index=i) + for i, emb_list in enumerate(embedding_value) + ] + else: + log.warning(f"Empty or invalid embedding data: {embedding_value}") + embedding_data = [] + elif 'embeddings' in response: + # Alternative batch format: {'embeddings': [{'embedding': [float, ...]}, ...]} + embedding_data = [ + Embedding(embedding=item['embedding'], index=i) + for i, item in enumerate(response['embeddings']) + ] + else: + log.warning(f"Unexpected response structure: {response.keys()}") + embedding_data = [] + elif hasattr(response, 'embeddings'): + # Custom batch response object from our implementation + embedding_data = [ + Embedding(embedding=emb, index=i) + for i, emb in enumerate(response.embeddings) + ] + else: + log.warning(f"Unexpected response type: {type(response)}") + embedding_data = [] + + return EmbedderOutput( + data=embedding_data, + error=None, + raw_response=response + ) + except Exception as e: + log.error(f"Error parsing Google AI embedding response: {e}") + return EmbedderOutput( + data=[], + error=str(e), + raw_response=response + ) + + def convert_inputs_to_api_kwargs( + self, + input: Optional[Any] = None, + model_kwargs: Dict = {}, + model_type: ModelType = ModelType.UNDEFINED, + ) -> Dict: + """Convert inputs to Google AI API format. 
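+
+        Example (illustrative; dict key order follows insertion):
+            >>> client.convert_inputs_to_api_kwargs(
+            ...     input=["a", "b"],
+            ...     model_kwargs={"model": "text-embedding-004"},
+            ...     model_type=ModelType.EMBEDDER)
+            {'model': 'text-embedding-004', 'contents': ['a', 'b'], 'task_type': 'SEMANTIC_SIMILARITY'}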
+ + Args: + input: Text input(s) to embed + model_kwargs: Model parameters including model name and task_type + model_type: Should be ModelType.EMBEDDER for this client + + Returns: + Dict: API kwargs for Google AI embedding call + """ + if model_type != ModelType.EMBEDDER: + raise ValueError(f"GoogleEmbedderClient only supports EMBEDDER model type, got {model_type}") + + # Ensure input is a list + if isinstance(input, str): + content = [input] + elif isinstance(input, Sequence): + content = list(input) + else: + raise TypeError("input must be a string or sequence of strings") + + final_model_kwargs = model_kwargs.copy() + + # Handle single vs batch embedding + if len(content) == 1: + final_model_kwargs["content"] = content[0] + else: + final_model_kwargs["contents"] = content + + # Set default task type if not provided + if "task_type" not in final_model_kwargs: + final_model_kwargs["task_type"] = "SEMANTIC_SIMILARITY" + + # Set default model if not provided + if "model" not in final_model_kwargs: + final_model_kwargs["model"] = "text-embedding-004" + + return final_model_kwargs + + @backoff.on_exception( + backoff.expo, + (Exception,), # Google AI may raise various exceptions + max_time=5, + ) + def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED): + """Call Google AI embedding API. + + Args: + api_kwargs: API parameters + model_type: Should be ModelType.EMBEDDER + + Returns: + Google AI embedding response + """ + if model_type != ModelType.EMBEDDER: + raise ValueError(f"GoogleEmbedderClient only supports EMBEDDER model type") + + log.info(f"Google AI Embeddings API kwargs: {api_kwargs}") + + try: + # Use embed_content for single text or batch embedding + if "content" in api_kwargs: + # Single embedding + response = genai.embed_content(**api_kwargs) + elif "contents" in api_kwargs: + # Batch embedding - Google AI supports batch natively + contents = api_kwargs.pop("contents") + response = genai.embed_content(content=contents, **api_kwargs) + else: + raise ValueError("Either 'content' or 'contents' must be provided") + + return response + + except Exception as e: + log.error(f"Error calling Google AI Embeddings API: {e}") + raise + + async def acall(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED): + """Async call to Google AI embedding API. + + Note: Google AI Python client doesn't have async support yet, + so this falls back to synchronous call. 
+ """ + # Google AI client doesn't have async support yet + return self.call(api_kwargs, model_type) \ No newline at end of file diff --git a/api/main.py b/api/main.py index a1989261..791e31b7 100644 --- a/api/main.py +++ b/api/main.py @@ -1,4 +1,3 @@ -import uvicorn import os import sys import logging @@ -13,9 +12,38 @@ setup_logging() logger = logging.getLogger(__name__) +# Configure watchfiles logger to show file paths +watchfiles_logger = logging.getLogger("watchfiles.main") +watchfiles_logger.setLevel(logging.DEBUG) # Enable DEBUG to see file paths + # Add the current directory to the path so we can import the api package sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# Apply watchfiles monkey patch BEFORE uvicorn import +is_development = os.environ.get("NODE_ENV") != "production" +if is_development: + import watchfiles + current_dir = os.path.dirname(os.path.abspath(__file__)) + logs_dir = os.path.join(current_dir, "logs") + + original_watch = watchfiles.watch + def patched_watch(*args, **kwargs): + # Only watch the api directory but exclude logs subdirectory + # Instead of watching the entire api directory, watch specific subdirectories + api_subdirs = [] + for item in os.listdir(current_dir): + item_path = os.path.join(current_dir, item) + if os.path.isdir(item_path) and item != "logs": + api_subdirs.append(item_path) + + # Also add Python files in the api root directory + api_subdirs.append(current_dir + "/*.py") + + return original_watch(*api_subdirs, **kwargs) + watchfiles.watch = patched_watch + +import uvicorn + # Check for required environment variables required_env_vars = ['GOOGLE_API_KEY', 'OPENAI_API_KEY'] missing_vars = [var for var in required_env_vars if not os.environ.get(var)] @@ -42,16 +70,10 @@ logger.info(f"Starting Streaming API on port {port}") # Run the FastAPI app with uvicorn - # Disable reload in production/Docker environment - is_development = os.environ.get("NODE_ENV") != "production" - - if is_development: - # Prevent infinite logging loop caused by file changes triggering log writes - logging.getLogger("watchfiles.main").setLevel(logging.WARNING) - uvicorn.run( "api.api:app", host="0.0.0.0", port=port, - reload=is_development + reload=is_development, + reload_excludes=["**/logs/*", "**/__pycache__/*", "**/*.pyc"] if is_development else None, ) diff --git a/api/rag.py b/api/rag.py index 59142463..fa014b9d 100644 --- a/api/rag.py +++ b/api/rag.py @@ -222,10 +222,11 @@ def __init__(self, provider="google", model=None, use_s3: bool = False): # noqa self.model = model # Import the helper functions - from api.config import get_embedder_config, is_ollama_embedder + from api.config import get_embedder_config, get_embedder_type - # Determine if we're using Ollama embedder based on configuration - self.is_ollama_embedder = is_ollama_embedder() + # Determine embedder type based on current configuration + self.embedder_type = get_embedder_type() + self.is_ollama_embedder = (self.embedder_type == 'ollama') # Backward compatibility # Check if Ollama model exists before proceeding if self.is_ollama_embedder: @@ -240,7 +241,7 @@ def __init__(self, provider="google", model=None, use_s3: bool = False): # noqa # Initialize components self.memory = Memory() - self.embedder = get_embedder() + self.embedder = get_embedder(embedder_type=self.embedder_type) # Patch: ensure query embedding is always single string for Ollama def single_string_embedder(query): @@ -412,7 +413,7 @@ def prepare_retriever(self, repo_url_or_path: str, type: str = "github", 
access_ repo_url_or_path, type, access_token, - is_ollama_embedder=self.is_ollama_embedder, + embedder_type=self.embedder_type, excluded_dirs=excluded_dirs, excluded_files=excluded_files, included_dirs=included_dirs, diff --git a/api/tools/embedder.py b/api/tools/embedder.py index cb2eb51f..fcdab3d3 100644 --- a/api/tools/embedder.py +++ b/api/tools/embedder.py @@ -1,10 +1,40 @@ import adalflow as adal -from api.config import configs +from api.config import configs, get_embedder_type -def get_embedder() -> adal.Embedder: - embedder_config = configs["embedder"] +def get_embedder(is_local_ollama: bool = False, use_google_embedder: bool = False, embedder_type: str = None) -> adal.Embedder: + """Get embedder based on configuration or parameters. + + Args: + is_local_ollama: Legacy parameter for Ollama embedder + use_google_embedder: Legacy parameter for Google embedder + embedder_type: Direct specification of embedder type ('ollama', 'google', 'openai') + + Returns: + adal.Embedder: Configured embedder instance + """ + # Determine which embedder config to use + if embedder_type: + if embedder_type == 'ollama': + embedder_config = configs["embedder_ollama"] + elif embedder_type == 'google': + embedder_config = configs["embedder_google"] + else: # default to openai + embedder_config = configs["embedder"] + elif is_local_ollama: + embedder_config = configs["embedder_ollama"] + elif use_google_embedder: + embedder_config = configs["embedder_google"] + else: + # Auto-detect based on current configuration + current_type = get_embedder_type() + if current_type == 'ollama': + embedder_config = configs["embedder_ollama"] + elif current_type == 'google': + embedder_config = configs["embedder_google"] + else: + embedder_config = configs["embedder"] # --- Initialize Embedder --- model_client_class = embedder_config["model_client"] @@ -12,8 +42,13 @@ def get_embedder() -> adal.Embedder: model_client = model_client_class(**embedder_config["initialize_kwargs"]) else: model_client = model_client_class() - embedder = adal.Embedder( - model_client=model_client, - model_kwargs=embedder_config["model_kwargs"], - ) + + # Create embedder with basic parameters + embedder_kwargs = {"model_client": model_client, "model_kwargs": embedder_config["model_kwargs"]} + + embedder = adal.Embedder(**embedder_kwargs) + + # Set batch_size as an attribute if available (not a constructor parameter) + if "batch_size" in embedder_config: + embedder.batch_size = embedder_config["batch_size"] return embedder diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..5d3109a7 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,126 @@ +# DeepWiki Tests + +This directory contains all tests for the DeepWiki project, organized by type and scope. 
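+
+As a taste of what these suites exercise, a minimal embedder smoke check (illustrative, mirroring the unit tests below) looks like:
+
+```python
+from api.tools.embedder import get_embedder
+
+embedder = get_embedder(embedder_type="google")  # or "openai" / "ollama"
+result = embedder("Hello world")                 # EmbedderOutput with a .data list
+assert result.data, "expected at least one embedding vector"
+```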
+
+## Directory Structure
+
+```
+tests/
+├── unit/                              # Unit tests - test individual components in isolation
+│   ├── test_google_embedder.py       # Tests for Google AI embedder client
+│   └── test_google_embedder_fix.py   # Tests for embedding response parsing fix
+├── integration/                       # Integration tests - test component interactions
+│   └── test_full_integration.py      # Full pipeline integration test
+├── api/                               # API tests - test HTTP endpoints
+│   └── test_api.py                    # API endpoint tests
+└── run_tests.py                       # Test runner script
+```
+
+## Running Tests
+
+### All Tests
+```bash
+python tests/run_tests.py
+```
+
+### Unit Tests Only
+```bash
+python tests/run_tests.py --unit
+```
+
+### Integration Tests Only
+```bash
+python tests/run_tests.py --integration
+```
+
+### API Tests Only
+```bash
+python tests/run_tests.py --api
+```
+
+### Individual Test Files
+```bash
+# Unit tests
+python tests/unit/test_google_embedder.py
+python tests/unit/test_google_embedder_fix.py
+
+# Integration tests
+python tests/integration/test_full_integration.py
+
+# API tests
+python tests/api/test_api.py
+```
+
+## Test Requirements
+
+### Environment Variables
+- `GOOGLE_API_KEY`: Required for Google AI embedder tests
+- `OPENAI_API_KEY`: Required for some integration tests
+- `DEEPWIKI_EMBEDDER_TYPE`: Set to 'google' for Google embedder tests
+
+### Dependencies
+All test dependencies are included in the main project requirements:
+- `python-dotenv`: For loading environment variables
+- `adalflow`: Core framework for embeddings
+- `google-generativeai`: Google AI API client
+- `requests`: For API testing
+
+## Test Categories
+
+### Unit Tests
+- **Purpose**: Test individual components in isolation
+- **Speed**: Fast (< 1 second per test)
+- **Dependencies**: Minimal external dependencies
+- **Examples**: Testing embedder response parsing, configuration loading
+
+### Integration Tests
+- **Purpose**: Test how components work together
+- **Speed**: Medium (1-10 seconds per test)
+- **Dependencies**: May require API keys and external services
+- **Examples**: End-to-end embedding pipeline, RAG workflow
+
+### API Tests
+- **Purpose**: Test HTTP endpoints and WebSocket connections
+- **Speed**: Medium-slow (5-30 seconds per test)
+- **Dependencies**: Requires running API server
+- **Examples**: Chat completion endpoints, streaming responses
+
+## Adding New Tests
+
+1. **Choose the right category**: Determine if your test is unit, integration, or API
+2. **Create the test file**: Place it in the appropriate subdirectory
+3. **Follow the naming convention**: `test_*.py`
+4. **Add proper imports**: Use the project root path setup pattern
+5. **Document the test**: Add docstrings explaining what the test does
+6. 
**Update this README**: Add your test to the appropriate section
+
+## Troubleshooting
+
+### Import Errors
+If you get import errors, ensure the test file includes the project root path setup:
+
+```python
+from pathlib import Path
+import sys
+
+# Add the project root to the Python path
+project_root = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(project_root))
+```
+
+### API Key Issues
+Make sure you have a `.env` file in the project root with the required API keys:
+
+```
+GOOGLE_API_KEY=your_google_api_key_here
+OPENAI_API_KEY=your_openai_api_key_here
+DEEPWIKI_EMBEDDER_TYPE=google
+```
+
+### Server Dependencies
+For API tests, ensure the FastAPI server is running on the expected port:
+
+```bash
+cd api
+python main.py
+```
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..61e31165
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# Tests for DeepWiki
\ No newline at end of file
diff --git a/tests/api/__init__.py b/tests/api/__init__.py
new file mode 100644
index 00000000..3cd837b0
--- /dev/null
+++ b/tests/api/__init__.py
@@ -0,0 +1 @@
+# API tests
\ No newline at end of file
diff --git a/api/test_api.py b/tests/api/test_api.py
similarity index 100%
rename from api/test_api.py
rename to tests/api/test_api.py
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
new file mode 100644
index 00000000..11b88fa6
--- /dev/null
+++ b/tests/integration/__init__.py
@@ -0,0 +1 @@
+# Integration tests
\ No newline at end of file
diff --git a/tests/integration/test_full_integration.py b/tests/integration/test_full_integration.py
new file mode 100644
index 00000000..21c85e3e
--- /dev/null
+++ b/tests/integration/test_full_integration.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python3
+"""Full integration test for Google AI embeddings."""
+
+import os
+import sys
+import json
+from pathlib import Path
+
+# Add the project root to the Python path
+project_root = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(project_root))
+
+def test_config_loading():
+    """Test that configurations load properly."""
+    print("🔧 Testing configuration loading...")
+
+    try:
+        from api.config import configs, CLIENT_CLASSES
+
+        # Check if Google embedder config exists
+        if 'embedder_google' in configs:
+            print("✅ embedder_google configuration found")
+            google_config = configs['embedder_google']
+            print(f"📋 Google config: {json.dumps(google_config, indent=2, default=str)}")
+        else:
+            print("❌ embedder_google configuration not found")
+            return False
+
+        # Check if GoogleEmbedderClient is in CLIENT_CLASSES
+        if 'GoogleEmbedderClient' in CLIENT_CLASSES:
+            print("✅ GoogleEmbedderClient found in CLIENT_CLASSES")
+        else:
+            print("❌ GoogleEmbedderClient not found in CLIENT_CLASSES")
+            return False
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Error loading configuration: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def test_embedder_selection():
+    """Test embedder selection mechanism."""
+    print("\n🔧 Testing embedder selection...")
+
+    try:
+        from api.tools.embedder import get_embedder
+        from api.config import get_embedder_type, is_google_embedder
+
+        # Test default embedder type
+        current_type = get_embedder_type()
+        print(f"📋 Current embedder type: {current_type}")
+
+        # Test is_google_embedder function
+        is_google = is_google_embedder()
+        print(f"📋 Is Google embedder: {is_google}")
+
+        # Test get_embedder with google type
+        print("🧪 Testing get_embedder with embedder_type='google'...")
+        embedder = get_embedder(embedder_type='google')
+        print(f"✅ Google embedder created: {type(embedder)}")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Error testing embedder selection: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def test_google_embedder_with_env():
+    """Test Google embedder with environment variable."""
+    print("\n🔧 Testing with DEEPWIKI_EMBEDDER_TYPE=google...")
+
+    # Set environment variable
+    original_value = os.environ.get('DEEPWIKI_EMBEDDER_TYPE')
+    os.environ['DEEPWIKI_EMBEDDER_TYPE'] = 'google'
+
+    try:
+        # Reload config module to pick up new env var
+        import importlib
+        import api.config
+        importlib.reload(api.config)
+
+        from api.config import EMBEDDER_TYPE, get_embedder_type, get_embedder_config
+        from api.tools.embedder import get_embedder
+
+        print(f"📋 EMBEDDER_TYPE: {EMBEDDER_TYPE}")
+        print(f"📋 get_embedder_type(): {get_embedder_type()}")
+
+        # Test getting embedder config
+        config = get_embedder_config()
+        print(f"📋 Current embedder config client: {config.get('client_class', 'Unknown')}")
+
+        # Test creating embedder
+        embedder = get_embedder()
+        print(f"✅ Embedder created with google env var: {type(embedder)}")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Error testing with environment variable: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+    finally:
+        # Restore original environment variable
+        if original_value is not None:
+            os.environ['DEEPWIKI_EMBEDDER_TYPE'] = original_value
+        elif 'DEEPWIKI_EMBEDDER_TYPE' in os.environ:
+            del os.environ['DEEPWIKI_EMBEDDER_TYPE']
+
+def main():
+    """Run all integration tests."""
+    print("🚀 Starting Google AI Embeddings Integration Tests")
+    print("=" * 60)
+
+    tests = [
+        test_config_loading,
+        test_embedder_selection,
+        test_google_embedder_with_env,
+    ]
+
+    passed = 0
+    total = len(tests)
+
+    for test in tests:
+        try:
+            if test():
+                passed += 1
+                print("✅ PASSED")
+            else:
+                print("❌ FAILED")
+        except Exception as e:
+            print(f"❌ FAILED with exception: {e}")
+        print("-" * 40)
+
+    print(f"\n📊 Test Results: {passed}/{total} tests passed")
+
+    if passed == total:
+        print("🎉 All integration tests passed!")
+        return True
+    else:
+        print("💥 Some tests failed!")
+        return False
+
+if __name__ == "__main__":
+    success = main()
+    sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/tests/run_tests.py b/tests/run_tests.py
new file mode 100644
index 00000000..9c93ae59
--- /dev/null
+++ b/tests/run_tests.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+"""
+Test runner for DeepWiki project.
+
+This script provides a unified way to run all tests or specific test categories.
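+
+Usage (illustrative):
+    python tests/run_tests.py              # run every category
+    python tests/run_tests.py --unit       # unit tests only
+    python tests/run_tests.py --check-env  # check keys and dependencies, then exit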
+""" + +import os +import sys +import argparse +import subprocess +from pathlib import Path + +# Add the project root to the Python path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +def run_test_file(test_file): + """Run a single test file and return success status.""" + print(f"\n๐Ÿงช Running {test_file}...") + try: + result = subprocess.run([sys.executable, str(test_file)], + capture_output=True, text=True, cwd=project_root) + + if result.returncode == 0: + print(f"โœ… {test_file.name} - PASSED") + if result.stdout: + print(f"๐Ÿ“„ Output:\n{result.stdout}") + return True + else: + print(f"โŒ {test_file.name} - FAILED") + if result.stderr: + print(f"๐Ÿ’ฅ Error:\n{result.stderr}") + if result.stdout: + print(f"๐Ÿ“„ Output:\n{result.stdout}") + return False + except Exception as e: + print(f"๐Ÿ’ฅ {test_file.name} - ERROR: {e}") + return False + +def run_tests(test_dirs): + """Run all tests in the specified directories.""" + total_tests = 0 + passed_tests = 0 + failed_tests = [] + + for test_dir in test_dirs: + test_path = Path(__file__).parent / test_dir + if not test_path.exists(): + print(f"โš ๏ธ Warning: Test directory {test_dir} not found") + continue + + test_files = list(test_path.glob("test_*.py")) + if not test_files: + print(f"โš ๏ธ No test files found in {test_dir}") + continue + + print(f"\n๐Ÿ“ Running {test_dir} tests...") + for test_file in sorted(test_files): + total_tests += 1 + if run_test_file(test_file): + passed_tests += 1 + else: + failed_tests.append(str(test_file)) + + # Print summary + print(f"\n{'='*50}") + print(f"๐Ÿ“Š TEST SUMMARY") + print(f"{'='*50}") + print(f"Total tests: {total_tests}") + print(f"Passed: {passed_tests}") + print(f"Failed: {len(failed_tests)}") + + if failed_tests: + print(f"\nโŒ Failed tests:") + for test in failed_tests: + print(f" - {test}") + print(f"\n๐Ÿ’ก Tip: Run individual failed tests for more details") + return False + else: + print(f"\n๐ŸŽ‰ All tests passed!") + return True + +def check_environment(): + """Check if required environment variables and dependencies are available.""" + print("๐Ÿ”ง Checking test environment...") + + # Check for .env file + env_file = project_root / ".env" + if env_file.exists(): + print("โœ… .env file found") + from dotenv import load_dotenv + load_dotenv(env_file) + else: + print("โš ๏ธ No .env file found - some tests may fail without API keys") + + # Check for API keys + api_keys = { + "GOOGLE_API_KEY": "Google AI embedder tests", + "OPENAI_API_KEY": "OpenAI integration tests" + } + + for key, purpose in api_keys.items(): + if os.getenv(key): + print(f"โœ… {key} is set ({purpose})") + else: + print(f"โš ๏ธ {key} not set - {purpose} may fail") + + # Check Python dependencies + try: + import adalflow + print("โœ… adalflow available") + except ImportError: + print("โŒ adalflow not available - install with: pip install adalflow") + + try: + import google.generativeai + print("โœ… google-generativeai available") + except ImportError: + print("โŒ google-generativeai not available - install with: pip install google-generativeai") + + try: + import requests + print("โœ… requests available") + except ImportError: + print("โŒ requests not available - install with: pip install requests") + +def main(): + parser = argparse.ArgumentParser(description="Run DeepWiki tests") + parser.add_argument("--unit", action="store_true", help="Run only unit tests") + parser.add_argument("--integration", action="store_true", help="Run only integration tests") + 
parser.add_argument("--api", action="store_true", help="Run only API tests") + parser.add_argument("--check-env", action="store_true", help="Only check environment setup") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + + args = parser.parse_args() + + # Check environment first + check_environment() + + if args.check_env: + return + + # Determine which tests to run + test_dirs = [] + if args.unit: + test_dirs.append("unit") + if args.integration: + test_dirs.append("integration") + if args.api: + test_dirs.append("api") + + # If no specific category selected, run all + if not test_dirs: + test_dirs = ["unit", "integration", "api"] + + print(f"\n๐Ÿš€ Starting test run for: {', '.join(test_dirs)}") + + success = run_tests(test_dirs) + sys.exit(0 if success else 1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..4d46ee58 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +# Unit tests \ No newline at end of file diff --git a/tests/unit/test_all_embedders.py b/tests/unit/test_all_embedders.py new file mode 100644 index 00000000..416373df --- /dev/null +++ b/tests/unit/test_all_embedders.py @@ -0,0 +1,464 @@ +#!/usr/bin/env python3 +""" +Comprehensive test suite for all embedder types (OpenAI, Google, Ollama). +This test file validates the embedder system before any modifications are made. +""" + +import os +import sys +import logging +from pathlib import Path +from unittest.mock import patch, MagicMock + +# Add the project root to the Python path +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + +# Set up environment +from dotenv import load_dotenv +load_dotenv() + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Simple test framework without pytest +class TestRunner: + def __init__(self): + self.tests_run = 0 + self.tests_passed = 0 + self.tests_failed = 0 + self.failures = [] + + def run_test(self, test_func, test_name=None): + """Run a single test function.""" + if test_name is None: + test_name = test_func.__name__ + + self.tests_run += 1 + try: + logger.info(f"Running test: {test_name}") + test_func() + self.tests_passed += 1 + logger.info(f"โœ… {test_name} PASSED") + return True + except Exception as e: + self.tests_failed += 1 + self.failures.append((test_name, str(e))) + logger.error(f"โŒ {test_name} FAILED: {e}") + return False + + def run_test_class(self, test_class): + """Run all test methods in a test class.""" + instance = test_class() + test_methods = [getattr(instance, method) for method in dir(instance) + if method.startswith('test_') and callable(getattr(instance, method))] + + for test_method in test_methods: + test_name = f"{test_class.__name__}.{test_method.__name__}" + self.run_test(test_method, test_name) + + def run_parametrized_test(self, test_func, parameters, test_name_base=None): + """Run a test function with multiple parameter sets.""" + if test_name_base is None: + test_name_base = test_func.__name__ + + for i, param in enumerate(parameters): + test_name = f"{test_name_base}[{param}]" + self.run_test(lambda: test_func(param), test_name) + + def summary(self): + """Print test summary.""" + logger.info(f"\n๐Ÿ“Š Test Summary:") + logger.info(f"Tests run: {self.tests_run}") + logger.info(f"Passed: {self.tests_passed}") + logger.info(f"Failed: 
{self.tests_failed}") + + if self.failures: + logger.error("\nโŒ Failed tests:") + for test_name, error in self.failures: + logger.error(f" - {test_name}: {error}") + + return self.tests_failed == 0 + +class TestEmbedderConfiguration: + """Test embedder configuration system.""" + + def test_config_loading(self): + """Test that all embedder configurations load properly.""" + from api.config import configs, CLIENT_CLASSES + + # Check all embedder configurations exist + assert 'embedder' in configs, "OpenAI embedder config missing" + assert 'embedder_google' in configs, "Google embedder config missing" + assert 'embedder_ollama' in configs, "Ollama embedder config missing" + + # Check client classes are available + assert 'OpenAIClient' in CLIENT_CLASSES, "OpenAIClient missing from CLIENT_CLASSES" + assert 'GoogleEmbedderClient' in CLIENT_CLASSES, "GoogleEmbedderClient missing from CLIENT_CLASSES" + assert 'OllamaClient' in CLIENT_CLASSES, "OllamaClient missing from CLIENT_CLASSES" + + def test_embedder_type_detection(self): + """Test embedder type detection functions.""" + from api.config import get_embedder_type, is_ollama_embedder, is_google_embedder + + # Default type should be detected + current_type = get_embedder_type() + assert current_type in ['openai', 'google', 'ollama'], f"Invalid embedder type: {current_type}" + + # Boolean functions should work + is_ollama = is_ollama_embedder() + is_google = is_google_embedder() + assert isinstance(is_ollama, bool), "is_ollama_embedder should return boolean" + assert isinstance(is_google, bool), "is_google_embedder should return boolean" + + # Only one should be true at a time (unless using openai default) + if current_type == 'ollama': + assert is_ollama and not is_google + elif current_type == 'google': + assert not is_ollama and is_google + else: # openai + assert not is_ollama and not is_google + + def test_get_embedder_config(self, embedder_type=None): + """Test getting embedder config for each type.""" + from api.config import get_embedder_config + + if embedder_type: + # Mock the EMBEDDER_TYPE for testing + with patch('api.config.EMBEDDER_TYPE', embedder_type): + config = get_embedder_config() + assert isinstance(config, dict), f"Config for {embedder_type} should be dict" + assert 'model_client' in config or 'client_class' in config, f"No client specified for {embedder_type}" + else: + # Test current configuration + config = get_embedder_config() + assert isinstance(config, dict), "Config should be dict" + assert 'model_client' in config or 'client_class' in config, "No client specified" + + +class TestEmbedderFactory: + """Test the embedder factory function.""" + + def test_get_embedder_with_explicit_type(self): + """Test get_embedder with explicit embedder_type parameter.""" + from api.tools.embedder import get_embedder + + # Test Google embedder + google_embedder = get_embedder(embedder_type='google') + assert google_embedder is not None, "Google embedder should be created" + + # Test OpenAI embedder + openai_embedder = get_embedder(embedder_type='openai') + assert openai_embedder is not None, "OpenAI embedder should be created" + + # Test Ollama embedder (may fail if Ollama not available, but should not crash) + try: + ollama_embedder = get_embedder(embedder_type='ollama') + assert ollama_embedder is not None, "Ollama embedder should be created" + except Exception as e: + logger.warning(f"Ollama embedder creation failed (expected if Ollama not available): {e}") + + def test_get_embedder_with_legacy_params(self): + """Test 
get_embedder with legacy boolean parameters.""" + from api.tools.embedder import get_embedder + + # Test with use_google_embedder=True + google_embedder = get_embedder(use_google_embedder=True) + assert google_embedder is not None, "Google embedder should be created with use_google_embedder=True" + + # Test with is_local_ollama=True + try: + ollama_embedder = get_embedder(is_local_ollama=True) + assert ollama_embedder is not None, "Ollama embedder should be created with is_local_ollama=True" + except Exception as e: + logger.warning(f"Ollama embedder creation failed (expected if Ollama not available): {e}") + + def test_get_embedder_auto_detection(self): + """Test get_embedder with automatic type detection.""" + from api.tools.embedder import get_embedder + + # Test auto-detection (should use current configuration) + embedder = get_embedder() + assert embedder is not None, "Auto-detected embedder should be created" + + +class TestEmbedderClients: + """Test individual embedder clients.""" + + def test_google_embedder_client(self): + """Test Google embedder client directly.""" + if not os.getenv('GOOGLE_API_KEY'): + logger.warning("Skipping Google embedder test - GOOGLE_API_KEY not available") + return + + from api.google_embedder_client import GoogleEmbedderClient + from adalflow.core.types import ModelType + + client = GoogleEmbedderClient() + + # Test single embedding + api_kwargs = client.convert_inputs_to_api_kwargs( + input="Hello world", + model_kwargs={"model": "text-embedding-004", "task_type": "SEMANTIC_SIMILARITY"}, + model_type=ModelType.EMBEDDER + ) + + response = client.call(api_kwargs, ModelType.EMBEDDER) + assert response is not None, "Google embedder should return response" + + # Parse the response + parsed = client.parse_embedding_response(response) + assert parsed.data is not None, "Parsed response should have data" + assert len(parsed.data) > 0, "Should have at least one embedding" + assert parsed.error is None, "Should not have errors" + + def test_openai_embedder_via_adalflow(self): + """Test OpenAI embedder through AdalFlow.""" + if not os.getenv('OPENAI_API_KEY'): + logger.warning("Skipping OpenAI embedder test - OPENAI_API_KEY not available") + return + + import adalflow as adal + from api.openai_client import OpenAIClient + + client = OpenAIClient() + embedder = adal.Embedder( + model_client=client, + model_kwargs={"model": "text-embedding-3-small", "dimensions": 256} + ) + + result = embedder("Hello world") + assert result is not None, "OpenAI embedder should return result" + assert hasattr(result, 'data'), "Result should have data attribute" + assert len(result.data) > 0, "Should have at least one embedding" + + +class TestDataPipelineFunctions: + """Test data pipeline functions that use embedders.""" + + def test_count_tokens(self, embedder_type=None): + """Test token counting with different embedder types.""" + from api.data_pipeline import count_tokens + + test_text = "This is a test string for token counting." 
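+        # NOTE: despite its name, the embedder_type parameter of this test is fed
+        # to the deprecated is_ollama_embedder flag below, exercising the legacy path.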
+
+        if embedder_type is not None:
+            # Test with specific is_ollama_embedder value
+            token_count = count_tokens(test_text, is_ollama_embedder=embedder_type)
+            assert isinstance(token_count, int), "Token count should be an integer"
+            assert token_count > 0, "Token count should be positive"
+        else:
+            # Test with all values
+            for is_ollama in [None, True, False]:
+                token_count = count_tokens(test_text, is_ollama_embedder=is_ollama)
+                assert isinstance(token_count, int), "Token count should be an integer"
+                assert token_count > 0, "Token count should be positive"
+
+    def test_prepare_data_pipeline(self, is_ollama=None):
+        """Test data pipeline preparation with different embedder types."""
+        from api.data_pipeline import prepare_data_pipeline
+
+        if is_ollama is not None:
+            try:
+                pipeline = prepare_data_pipeline(is_ollama_embedder=is_ollama)
+                assert pipeline is not None, "Data pipeline should be created"
+                assert hasattr(pipeline, '__call__'), "Pipeline should be callable"
+            except Exception as e:
+                # Some configurations might fail if services aren't available
+                logger.warning(f"Pipeline creation failed (might be expected): {e}")
+        else:
+            # Test with all values
+            for is_ollama_val in [None, True, False]:
+                try:
+                    pipeline = prepare_data_pipeline(is_ollama_embedder=is_ollama_val)
+                    assert pipeline is not None, "Data pipeline should be created"
+                    assert hasattr(pipeline, '__call__'), "Pipeline should be callable"
+                except Exception as e:
+                    logger.warning(f"Pipeline creation failed for is_ollama={is_ollama_val}: {e}")
+
+
+class TestRAGIntegration:
+    """Test RAG class integration with different embedders."""
+
+    def test_rag_initialization(self):
+        """Test RAG initialization with different embedder configurations."""
+        from api.rag import RAG
+
+        # Test with default configuration
+        try:
+            rag = RAG(provider="google", model="gemini-1.5-flash")
+            assert rag is not None, "RAG should be initialized"
+            assert hasattr(rag, 'embedder'), "RAG should have embedder"
+            assert hasattr(rag, 'is_ollama_embedder'), "RAG should have is_ollama_embedder attribute"
+        except Exception as e:
+            logger.warning(f"RAG initialization failed (might be expected if keys missing): {e}")
+
+    def test_rag_embedder_type_detection(self):
+        """Test that RAG correctly detects embedder type."""
+        from api.rag import RAG
+
+        try:
+            rag = RAG()
+            # Should have the embedder type detection logic
+            assert hasattr(rag, 'is_ollama_embedder'), "RAG should detect embedder type"
+            assert isinstance(rag.is_ollama_embedder, bool), "is_ollama_embedder should be boolean"
+        except Exception as e:
+            logger.warning(f"RAG initialization failed: {e}")
+
+
+class TestEnvironmentVariableHandling:
+    """Test embedder selection via environment variables."""
+
+    def test_embedder_type_env_var(self, embedder_type=None):
+        """Test embedder selection via DEEPWIKI_EMBEDDER_TYPE environment variable."""
+        import importlib
+        import api.config
+
+        if embedder_type:
+            # Test specific embedder type
+            self._test_single_embedder_type(embedder_type)
+        else:
+            # Test all embedder types
+            for et in ['openai', 'google', 'ollama']:
+                self._test_single_embedder_type(et)
+
+    def _test_single_embedder_type(self, embedder_type):
+        """Test a single embedder type."""
+        import importlib
+        import api.config
+
+        # Save original value
+        original_value = os.environ.get('DEEPWIKI_EMBEDDER_TYPE')
+
+        try:
+            # Set environment variable
+            os.environ['DEEPWIKI_EMBEDDER_TYPE'] = embedder_type
+
+            # Reload config to pick up new env var
+            importlib.reload(api.config)
+
+            from api.config import EMBEDDER_TYPE, get_embedder_type
+
+            assert EMBEDDER_TYPE == embedder_type, f"EMBEDDER_TYPE should be {embedder_type}"
+            assert get_embedder_type() == embedder_type, f"get_embedder_type() should return {embedder_type}"
+
+        finally:
+            # Restore original value
+            if original_value is not None:
+                os.environ['DEEPWIKI_EMBEDDER_TYPE'] = original_value
+            elif 'DEEPWIKI_EMBEDDER_TYPE' in os.environ:
+                del os.environ['DEEPWIKI_EMBEDDER_TYPE']
+
+            # Reload config to restore original state
+            importlib.reload(api.config)
+
+
+class TestIssuesIdentified:
+    """Test the specific issues identified in the codebase."""
+
+    def test_binary_assumptions_in_rag(self):
+        """Test that RAG doesn't make binary assumptions about embedders."""
+        from api.rag import RAG
+
+        # The current implementation only considers is_ollama_embedder
+        # This test documents the current behavior and will help verify fixes
+        try:
+            rag = RAG()
+
+            # Current implementation only has is_ollama_embedder
+            assert hasattr(rag, 'is_ollama_embedder'), "RAG should have is_ollama_embedder"
+
+            # This is the issue: no explicit support for Google embedder detection
+            # The fix should add proper embedder type detection
+
+        except Exception as e:
+            logger.warning(f"RAG test failed: {e}")
+
+    def test_binary_assumptions_in_data_pipeline(self):
+        """Test binary assumptions in data pipeline functions."""
+        from api.data_pipeline import prepare_data_pipeline, count_tokens
+
+        # These functions currently only consider the is_ollama_embedder parameter
+        # This test documents the issue and will verify fixes
+
+        # count_tokens only considers ollama vs non-ollama
+        token_count_ollama = count_tokens("test", is_ollama_embedder=True)
+        token_count_other = count_tokens("test", is_ollama_embedder=False)
+
+        assert isinstance(token_count_ollama, int)
+        assert isinstance(token_count_other, int)
+
+        # prepare_data_pipeline only accepts the is_ollama_embedder parameter
+        try:
+            pipeline_ollama = prepare_data_pipeline(is_ollama_embedder=True)
+            pipeline_other = prepare_data_pipeline(is_ollama_embedder=False)
+
+            assert pipeline_ollama is not None
+            assert pipeline_other is not None
+        except Exception as e:
+            logger.warning(f"Pipeline creation failed: {e}")
+
+
+def run_all_tests():
+    """Run all tests and return results."""
+    logger.info("Running comprehensive embedder tests...")
+
+    runner = TestRunner()
+
+    # Test classes to run
+    test_classes = [
+        TestEmbedderConfiguration,
+        TestEmbedderFactory,
+        TestEmbedderClients,
+        TestDataPipelineFunctions,
+        TestRAGIntegration,
+        TestEnvironmentVariableHandling,
+        TestIssuesIdentified
+    ]
+
+    # Run all test classes
+    for test_class in test_classes:
+        logger.info(f"\n🧪 Running {test_class.__name__}...")
+        runner.run_test_class(test_class)
+
+    # Run parametrized tests manually
+    logger.info("\n🧪 Running parametrized tests...")
+
+    # Test embedder config with different types
+    config_test = TestEmbedderConfiguration()
+    for embedder_type in ['openai', 'google', 'ollama']:
+        runner.run_test(
+            lambda et=embedder_type: config_test.test_get_embedder_config(et),
+            f"TestEmbedderConfiguration.test_get_embedder_config[{embedder_type}]"
+        )
+
+    # Test token counting with different is_ollama_embedder values
+    pipeline_test = TestDataPipelineFunctions()
+    for is_ollama in [None, True, False]:
+        runner.run_test(
+            lambda ol=is_ollama: pipeline_test.test_count_tokens(ol),
+            f"TestDataPipelineFunctions.test_count_tokens[{is_ollama}]"
+        )
+
+    # Test pipeline preparation with different types
+    for is_ollama in [None, True, False]:
+        runner.run_test(
+            lambda ol=is_ollama: pipeline_test.test_prepare_data_pipeline(ol),
+            f"TestDataPipelineFunctions.test_prepare_data_pipeline[{is_ollama}]"
+        )
+
+    # Test environment variable handling
+    env_test = TestEnvironmentVariableHandling()
+    for embedder_type in ['openai', 'google', 'ollama']:
+        runner.run_test(
+            lambda et=embedder_type: env_test.test_embedder_type_env_var(et),
+            f"TestEnvironmentVariableHandling.test_embedder_type_env_var[{embedder_type}]"
+        )
+
+    return runner.summary()
+
+
+if __name__ == "__main__":
+    success = run_all_tests()
+    sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/tests/unit/test_google_embedder.py b/tests/unit/test_google_embedder.py
new file mode 100644
index 00000000..368b22dd
--- /dev/null
+++ b/tests/unit/test_google_embedder.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+"""
+Test script to reproduce and fix Google embedder 'list' object has no attribute 'embedding' error.
+"""
+
+import os
+import sys
+import logging
+from pathlib import Path
+
+# Add the project root to the Python path
+project_root = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(project_root))
+
+# Set up environment
+from dotenv import load_dotenv
+load_dotenv()
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+def test_google_embedder_client():
+    """Test the Google embedder client directly."""
+    logger.info("Testing Google embedder client...")
+
+    try:
+        from api.google_embedder_client import GoogleEmbedderClient
+        from adalflow.core.types import ModelType
+
+        # Initialize the client
+        client = GoogleEmbedderClient()
+
+        # Test single embedding
+        logger.info("Testing single embedding...")
+        api_kwargs = client.convert_inputs_to_api_kwargs(
+            input="Hello world",
+            model_kwargs={"model": "text-embedding-004", "task_type": "SEMANTIC_SIMILARITY"},
+            model_type=ModelType.EMBEDDER
+        )
+
+        response = client.call(api_kwargs, ModelType.EMBEDDER)
+        logger.info(f"Single embedding response type: {type(response)}")
+        logger.info(f"Single embedding response keys: {list(response.keys()) if isinstance(response, dict) else 'Not a dict'}")
+
+        # Parse the response
+        parsed = client.parse_embedding_response(response)
+        logger.info(f"Parsed response data length: {len(parsed.data) if parsed.data else 0}")
+        logger.info(f"Parsed response error: {parsed.error}")
+
+        # Test batch embedding
+        logger.info("Testing batch embedding...")
+        api_kwargs = client.convert_inputs_to_api_kwargs(
+            input=["Hello world", "Test embedding"],
+            model_kwargs={"model": "text-embedding-004", "task_type": "SEMANTIC_SIMILARITY"},
+            model_type=ModelType.EMBEDDER
+        )
+
+        response = client.call(api_kwargs, ModelType.EMBEDDER)
+        logger.info(f"Batch embedding response type: {type(response)}")
+        logger.info(f"Batch embedding response keys: {list(response.keys()) if isinstance(response, dict) else 'Not a dict'}")
+
+        # Parse the response
+        parsed = client.parse_embedding_response(response)
+        logger.info(f"Parsed batch response data length: {len(parsed.data) if parsed.data else 0}")
+        logger.info(f"Parsed batch response error: {parsed.error}")
+
+        return True
+
+    except Exception as e:
+        logger.error(f"Error testing Google embedder client: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
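+# Illustrative sketch (not part of the test flow): the parsed result above is
+# expected to be an EmbedderOutput-style object, accessed roughly like:
+#
+#     parsed = client.parse_embedding_response(response)
+#     first = parsed.data[0]       # assumed to be an Embedding object
+#     vector = first.embedding     # list of floats, not a bare list entry
+#
+# If parse_embedding_response returned plain lists instead of Embedding
+# objects, downstream code accessing `.embedding` would raise the
+# "'list' object has no attribute 'embedding'" error this script reproduces.
+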
+def test_adalflow_embedder():
+    """Test the AdalFlow embedder with Google client."""
+    logger.info("Testing AdalFlow embedder with Google client...")
+
+    try:
+        import adalflow as adal
+        from api.google_embedder_client import GoogleEmbedderClient
+
+        # Create embedder
+        client = GoogleEmbedderClient()
+        embedder = adal.Embedder(
+            model_client=client,
+            model_kwargs={
+                "model": "text-embedding-004",
+                "task_type": "SEMANTIC_SIMILARITY"
+            }
+        )
+
+        # Test embedding
+        logger.info("Testing embedder with single input...")
+        result = embedder("Hello world")
+        logger.info(f"Embedder result type: {type(result)}")
+        logger.info(f"Embedder result: {result}")
+
+        if hasattr(result, 'data'):
+            logger.info(f"Result data length: {len(result.data) if result.data else 0}")
+
+        return True
+
+    except Exception as e:
+        logger.error(f"Error testing AdalFlow embedder: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def test_document_processing():
+    """Test document processing with Google embedder."""
+    logger.info("Testing document processing with Google embedder...")
+
+    try:
+        from adalflow.core.types import Document
+        from adalflow.components.data_process import ToEmbeddings
+        from api.tools.embedder import get_embedder
+
+        # Create some test documents
+        docs = [
+            Document(text="This is a test document.", meta_data={"file_path": "test1.txt"}),
+            Document(text="Another test document here.", meta_data={"file_path": "test2.txt"})
+        ]
+
+        # Get the Google embedder
+        embedder = get_embedder(embedder_type='google')
+        logger.info(f"Embedder type: {type(embedder)}")
+
+        # Process documents
+        embedder_transformer = ToEmbeddings(embedder=embedder, batch_size=100)
+
+        # Transform documents
+        logger.info("Transforming documents...")
+        transformed_docs = embedder_transformer(docs)
+
+        logger.info(f"Transformed docs type: {type(transformed_docs)}")
+        logger.info(f"Number of transformed docs: {len(transformed_docs)}")
+
+        # Check the structure
+        for i, doc in enumerate(transformed_docs):
+            logger.info(f"Doc {i} type: {type(doc)}")
+            logger.info(f"Doc {i} attributes: {dir(doc)}")
+            if hasattr(doc, 'vector'):
+                logger.info(f"Doc {i} vector type: {type(doc.vector)}")
+                logger.info(f"Doc {i} vector length: {len(doc.vector) if doc.vector is not None else 0}")
+            else:
+                logger.info(f"Doc {i} has no vector attribute")
+
+        return transformed_docs
+
+    except Exception as e:
+        logger.error(f"Error testing document processing: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def main():
+    """Main test function."""
+    logger.info("Starting Google embedder tests...")
+
+    # Test 1: Direct client test
+    if not test_google_embedder_client():
+        logger.error("Google embedder client test failed")
+        return False
+
+    # Test 2: AdalFlow embedder test
+    if not test_adalflow_embedder():
+        logger.error("AdalFlow embedder test failed")
+        return False
+
+    # Test 3: Document processing test
+    result = test_document_processing()
+    if result is False:
+        logger.error("Document processing test failed")
+        return False
+
+    logger.info("All tests completed successfully!")
+    return True
+
+if __name__ == "__main__":
+    success = main()
+    sys.exit(0 if success else 1)
\ No newline at end of file