From 2437e0738fbfd36f1d3b4e744ef1b967f0c70d31 Mon Sep 17 00:00:00 2001 From: MHHukiewitz Date: Thu, 12 Dec 2024 16:42:16 +0100 Subject: [PATCH 1/5] Add POST /strategies/suggest-parameters to allow AI-supported suggestion of parameters --- conf/conf_client.yml | 2 +- pyproject.toml | 41 ++-- routers/__init__.py | 3 + routers/strategies.py | 94 ++++++- routers/strategies_models.py | 190 ++++++++++---- services/__init__.py | 3 + services/libert_ai_service.py | 422 ++++++++++++++++++++++++++++++++ tests/__init__.py | 3 + tests/test_libert_ai_service.py | 164 +++++++++++++ tests/test_strategies.py | 224 +++++++++++++++++ 10 files changed, 1072 insertions(+), 74 deletions(-) create mode 100644 services/libert_ai_service.py create mode 100644 tests/__init__.py create mode 100644 tests/test_libert_ai_service.py create mode 100644 tests/test_strategies.py diff --git a/conf/conf_client.yml b/conf/conf_client.yml index dd8e81e..856639f 100644 --- a/conf/conf_client.yml +++ b/conf/conf_client.yml @@ -107,7 +107,7 @@ gateway: gateway_api_host: localhost gateway_api_port: '15888' -certs_path: ./certs +certs_path: certs # Whether to enable aggregated order and trade data collection anonymized_metrics_mode: diff --git a/pyproject.toml b/pyproject.toml index 4ad399a..ea36189 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,23 +1,26 @@ [build-system] -requires = ["setuptools>=42", "wheel"] -build-backend = "setuptools.build_meta" +requires = ["hatchling"] +build-backend = "hatchling.build" -[tool.black] -line-length = 130 -target-version = ["py38"] - -[tool.isort] -line_length = 130 -profile = "black" -multi_line_output = 3 -include_trailing_comma = true -use_parentheses = true -ensure_newline_before_comments = true -combine_as_imports = true +[project] +name = "robotter-backend-api" +version = "0.1.0" +authors = [ + { name="Mike Henry" }, +] +description = "Robotter Backend API" +readme = "README.md" +requires-python = ">=3.10" -[tool.pre-commit] -repos = [ - { repo = "https://github.com/pre-commit/pre-commit-hooks", rev = "v3.4.0", hooks = [{ id = "check-yaml" }, { id = "end-of-file-fixer" }] }, - { repo = "https://github.com/psf/black", rev = "21.6b0", hooks = [{ id = "black" }] }, - { repo = "https://github.com/pre-commit/mirrors-isort", rev = "v5.9.3", hooks = [{ id = "isort" }] } +[tool.pytest.ini_options] +pythonpath = [ + "." +] +asyncio_mode = "strict" +testpaths = [ + "tests", +] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::UserWarning", ] diff --git a/routers/__init__.py b/routers/__init__.py index e69de29..3ca04e5 100644 --- a/routers/__init__.py +++ b/routers/__init__.py @@ -0,0 +1,3 @@ +""" +Router package initialization. +""" diff --git a/routers/strategies.py b/routers/strategies.py index ce0fff3..af93e5b 100644 --- a/routers/strategies.py +++ b/routers/strategies.py @@ -1,11 +1,91 @@ -from typing import Dict +import json +import os +from typing import Dict, List +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel +from fastapi import FastAPI +from contextlib import asynccontextmanager -from fastapi import APIRouter -from .strategies_models import StrategyParameter, get_all_strategy_maps +from services.libert_ai_service import LibertAIService +from routers.strategies_models import ( + ParameterSuggestionRequest, + ParameterSuggestionResponse, + StrategyConfig, + discover_strategies, + get_strategy_mapping +) -router = APIRouter(tags=["Strategies"]) +# Create a libert_ai_service instance +libert_ai_service = LibertAIService() +@asynccontextmanager +async def lifespan(app: FastAPI): + # Initialize contexts on startup + try: + # Load strategies using auto-discovery + strategies = discover_strategies() + await libert_ai_service.initialize_contexts(strategies) + except Exception as e: + print(f"Error initializing LibertAI contexts: {str(e)}") + yield -@router.get("/strategies", response_model=Dict[str, Dict[str, StrategyParameter]]) -async def get_strategies(): - return get_all_strategy_maps() +# Create the FastAPI app with the lifespan handler +app = FastAPI(lifespan=lifespan) +router = APIRouter() + +@router.get("/strategies") +async def get_strategies() -> Dict[str, StrategyConfig]: + """Get all available strategies and their configurations.""" + try: + # Use auto-discovery to get strategies + return discover_strategies() + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/strategies/suggest-parameters") +async def suggest_parameters(request: ParameterSuggestionRequest) -> ParameterSuggestionResponse: + """ + Suggest parameter values for a strategy based on the provided parameters. + Uses LibertAI to analyze and suggest optimal values for missing or requested parameters. + + If requested_parameters is provided, will only suggest values for those specific parameters. + Otherwise, will suggest values for all missing required parameters. + """ + try: + # Get strategy configuration using auto-discovery + strategies = discover_strategies() + + if request.strategy_id not in strategies: + raise HTTPException(status_code=404, detail=f"Strategy '{request.strategy_id}' not found") + + strategy = strategies[request.strategy_id] + + try: + # Get suggestions from LibertAI + suggestions = await libert_ai_service.get_parameter_suggestions( + strategy_id=request.strategy_id, + strategy_config=strategy.parameters, + provided_params=request.parameters, + requested_params=request.requested_parameters + ) + + # Extract summary from the last suggestion if it exists + summary = "No suggestions available." + if suggestions: + # Remove the summary suggestion if it exists + if suggestions[-1].parameter_name.lower() == "summary": + summary = suggestions[-1].reasoning + suggestions = suggestions[:-1] + + return ParameterSuggestionResponse( + suggestions=suggestions, + summary=summary + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/routers/strategies_models.py b/routers/strategies_models.py index 033b7b5..27b16cc 100644 --- a/routers/strategies_models.py +++ b/routers/strategies_models.py @@ -1,6 +1,5 @@ from enum import Enum from typing import Any, Dict, Optional, Union, List - from pydantic import BaseModel, Field from decimal import Decimal from hummingbot.strategy_v2.controllers import MarketMakingControllerConfigBase, ControllerConfigBase, DirectionalTradingControllerConfigBase @@ -12,14 +11,26 @@ from pydantic.fields import ModelField from pydantic.main import ModelMetaclass -from bots.controllers.directional_trading.bollinger_v1 import BollingerV1ControllerConfig - logger = ( logging.getLogger(__name__) if __name__ != "__main__" else logging.getLogger("uvicorn") ) +class StrategyType(str, Enum): + DIRECTIONAL_TRADING = "directional_trading" + MARKET_MAKING = "market_making" + GENERIC = "generic" + +class StrategyMapping(BaseModel): + """Maps a strategy ID to its implementation details""" + id: str # e.g., "supertrend_v1" + config_class: str # e.g., "supertrendconfig" + module_path: str # e.g., "bots.controllers.directional_trading.supertrend_v1" + strategy_type: StrategyType + display_name: str # e.g., "Supertrend V1" + description: str = "" + class StrategyParameter(BaseModel): name: str group: str @@ -41,27 +52,29 @@ class StrategyParameter(BaseModel): is_integer: bool = False display_type: str = Field(default="input", description="Can be 'input', 'slider', 'dropdown', 'toggle', or 'date'") +class StrategyConfig(BaseModel): + """Complete strategy configuration including metadata and parameters""" + mapping: StrategyMapping + parameters: Dict[str, StrategyParameter] -def is_advanced_parameter(name: str) -> bool: - advanced_keywords = [ - "activation_bounds", "triple_barrier", "leverage", "dca", "macd", "natr", - "multiplier", "imbalance", "executor", "perp", "arbitrage" - ] - - simple_keywords = [ - "controller_name", "candles", "interval", "stop_loss", "take_profit", - "buy", "sell", "position_size", "time_limit", "spot" - ] - - name_lower = name.lower() - - if any(keyword in name_lower for keyword in advanced_keywords): - return True - - if any(keyword in name_lower for keyword in simple_keywords): - return False - - return True +class ParameterSuggestionRequest(BaseModel): + strategy_id: str + parameters: Dict[str, Any] + requested_parameters: Optional[List[str]] = Field( + default=None, + description="Optional list of specific parameters to get suggestions for. If not provided, will suggest values for all missing required parameters." + ) + +class ParameterSuggestion(BaseModel): + parameter_name: str + suggested_value: str + reasoning: str + differs_from_default: bool = False + differs_from_provided: bool = False + +class ParameterSuggestionResponse(BaseModel): + suggestions: List[ParameterSuggestion] + summary: str def get_strategy_display_info() -> Dict[str, Dict[str, str]]: """ @@ -85,7 +98,7 @@ def get_strategy_display_info() -> Dict[str, Dict[str, str]]: "pretty_name": "SuperTrend Strategy", "description": "Follows market trends to find good times to buy and sell." }, - + # Market Making Strategies "dman_maker_v2": { "pretty_name": "Smart Market Maker", @@ -99,7 +112,7 @@ def get_strategy_display_info() -> Dict[str, Dict[str, str]]: "pretty_name": "Simple Market Maker", "description": "Places basic buy and sell orders with fixed spreads." }, - + # Generic Strategies "spot_perp_arbitrage": { "pretty_name": "Spot-Futures Arbitrage", @@ -120,20 +133,20 @@ def convert_to_strategy_parameter(name: str, field: ModelField) -> StrategyParam required=field.required or field.default is not None, is_advanced=is_advanced_parameter(name), group=determine_parameter_group(name), - pretty_name=name.replace('_', ' ').title(), - description="", + pretty_name=name.replace('_', ' ').title(), + description="", ) - + # Get strategy display info strategy_info = get_strategy_display_info() - + # Try to find matching strategy info for strategy_name, info in strategy_info.items(): if strategy_name in name.lower(): param.pretty_name = info["pretty_name"] param.description = info["description"] break - + # structure of field client_data = field.field_info.extra.get('client_data') if client_data is not None and param.prompt == "": @@ -145,17 +158,23 @@ def convert_to_strategy_parameter(name: str, field: ModelField) -> StrategyParam param.display_type = "input" # Check for gt (greater than) and lt (less than) in field definition - if hasattr(field.field_info, 'gt'): - param.min_value = field.field_info.gt - if hasattr(field.field_info, 'lt'): - param.max_value = field.field_info.lt + if hasattr(field.field_info, 'ge'): + param.min_value = field.field_info.ge + elif hasattr(field.field_info, 'gt'): + param.min_value = field.field_info.gt + (1 if isinstance(field.field_info.gt, int) else Decimal('0')) + + if hasattr(field.field_info, 'le'): + param.max_value = field.field_info.le + elif hasattr(field.field_info, 'lt'): + param.max_value = field.field_info.lt - (1 if isinstance(field.field_info.lt, int) else Decimal('0')) + # Set display_type to "slider" only if both min and max values are present if param.min_value is not None and param.max_value is not None: param.display_type = "slider" elif param.type == "bool": param.display_type = "toggle" - + if "connector" in name.lower(): param.is_connector = True if "trading_pair" in name.lower(): @@ -170,11 +189,11 @@ def convert_to_strategy_parameter(name: str, field: ModelField) -> StrategyParam param.min_value = Decimal(0) if any(word in name.lower() for word in ["time", "interval", "duration"]): param.is_timespan = True - param.min_value = 0 - if param.type == "int": - param.is_integer = True - if any(word in name.lower() for word in ["executors", "workers"]): - param.display_type = "slider" + param.min_value = 0 + if param.type == "int": + param.is_integer = True + if any(word in name.lower() for word in ["executors", "workers"]): + param.display_type = "slider" param.min_value = 1 try: if issubclass(field.type_, Enum): @@ -206,10 +225,45 @@ def determine_parameter_group(name: str) -> str: else: return "Other" +def snake_case_to_real_name(snake_case: str) -> str: + return " ".join([word.capitalize() for word in snake_case.split("_")]) + +def infer_strategy_type(module_path: str, config_class: Any) -> StrategyType: + """Infer the strategy type from the module path and config class""" + if "directional_trading" in module_path: + return StrategyType.DIRECTIONAL_TRADING + elif "market_making" in module_path: + return StrategyType.MARKET_MAKING + else: + return StrategyType.GENERIC + +def generate_strategy_mapping(module_path: str, config_class: Any) -> StrategyMapping: + """Generate a strategy mapping from a config class""" + # Extract strategy ID from module path (e.g., "supertrend_v1" from "bots.controllers.directional_trading.supertrend_v1") + strategy_id = module_path.split(".")[-1] + + # Get strategy type + strategy_type = infer_strategy_type(module_path, config_class) + + # Generate display name + display_name = " ".join(word.capitalize() for word in strategy_id.split("_")) + + # Get description from class docstring + description = config_class.__doc__ or "" + + return StrategyMapping( + id=strategy_id, + config_class=config_class.__name__, + module_path=module_path, + strategy_type=strategy_type, + display_name=display_name, + description=description + ) @functools.lru_cache(maxsize=1) -def get_all_strategy_maps() -> Dict[str, Dict[str, StrategyParameter]]: - strategy_maps = {} +def discover_strategies() -> Dict[str, StrategyConfig]: + """Discover and load all available strategies""" + strategy_configs = {} controllers_dir = "bots/controllers" for root, dirs, files in os.walk(controllers_dir): @@ -230,17 +284,59 @@ def get_all_strategy_maps() -> Dict[str, Dict[str, StrategyParameter]]: and obj is not DirectionalTradingControllerConfigBase ): assert isinstance(obj, ModelMetaclass) - strategy_name = obj.controller_name if hasattr(obj, 'controller_name') else name.lower() + + # Generate mapping + mapping = generate_strategy_mapping(module_path, obj) + + # Convert parameters parameters = {} for field_name, field in obj.__fields__.items(): param = convert_to_strategy_parameter(field_name, field) parameters[field_name] = param - strategy_maps[strategy_name] = parameters + # Create complete strategy config + strategy_configs[mapping.id] = StrategyConfig( + mapping=mapping, + parameters=parameters + ) + except ImportError as e: - print(f"Error importing module {module_path}: {e}") + logger.error(f"Error importing module {module_path}: {e}") except Exception as e: - print(f"Unexpected error processing {module_path}: {e}") + logger.error(f"Unexpected error processing {module_path}: {e}") import traceback traceback.print_exc() - return strategy_maps + + return strategy_configs + +def get_strategy_mapping(strategy_id: str) -> Optional[StrategyMapping]: + """Get strategy mapping by ID""" + strategies = discover_strategies() + strategy = strategies.get(strategy_id) + return strategy.mapping if strategy else None + +def get_strategy_module_path(strategy_id: str) -> Optional[str]: + """Get the module path for a strategy""" + mapping = get_strategy_mapping(strategy_id) + return mapping.module_path if mapping else None + +def is_advanced_parameter(name: str) -> bool: + advanced_keywords = [ + "activation_bounds", "triple_barrier", "leverage", "dca", "macd", "natr", + "multiplier", "imbalance", "executor", "perp", "arbitrage" + ] + + simple_keywords = [ + "controller_name", "candles", "interval", "stop_loss", "take_profit", + "buy", "sell", "position_size", "time_limit", "spot" + ] + + name_lower = name.lower() + + if any(keyword in name_lower for keyword in advanced_keywords): + return True + + if any(keyword in name_lower for keyword in simple_keywords): + return False + + return True diff --git a/services/__init__.py b/services/__init__.py index e69de29..9c0bf67 100644 --- a/services/__init__.py +++ b/services/__init__.py @@ -0,0 +1,3 @@ +""" +Services package initialization. +""" diff --git a/services/libert_ai_service.py b/services/libert_ai_service.py new file mode 100644 index 0000000..127a2a3 --- /dev/null +++ b/services/libert_ai_service.py @@ -0,0 +1,422 @@ +import os +import json +import logging +import aiohttp +import inspect +import importlib +from typing import Dict, Any, List, Optional +from routers.strategies_models import ( + ParameterSuggestion, + StrategyConfig, + StrategyMapping, + discover_strategies +) + +logger = logging.getLogger(__name__) + +class LibertAIService: + def __init__(self): + # Hermes 2 pro + self.api_url = "https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion" + + self.strategy_slot_map: Dict[str, int] = {} # Maps strategy IDs to their slot IDs + self.next_slot_id = 0 + + async def initialize_contexts(self, strategies: Dict[str, StrategyConfig]): + """Initialize context slots for system prompt and each strategy.""" + try: + logger.info("Starting context initialization...") + + # Initialize system prompt in slot -1 + logger.info("Initializing system context...") + await self._initialize_system_context() + + # Initialize each strategy's context + for strategy_id, strategy_config in strategies.items(): + logger.info(f"Initializing context for strategy: {strategy_id}") + slot_id = self.next_slot_id + self.strategy_slot_map[strategy_id] = slot_id + self.next_slot_id += 1 + + # Load strategy implementation code + strategy_code = await self._load_strategy_code(strategy_config.mapping) + logger.info(f"Loaded strategy code for {strategy_id}, code length: {len(strategy_code)}") + + await self._initialize_strategy_context( + strategy_mapping=strategy_config.mapping, + strategy_config=strategy_config.parameters, + strategy_code=strategy_code, + slot_id=slot_id + ) + + logger.info(f"Context initialization complete. Strategy slot map: {self.strategy_slot_map}") + + except Exception as e: + logger.error(f"Error initializing contexts: {str(e)}") + raise + + async def _load_strategy_code(self, mapping: StrategyMapping) -> str: + """Load the strategy implementation code using the strategy mapping.""" + try: + # Import the module using the mapping's module path + module = importlib.import_module(mapping.module_path) + + # Get all classes in the module + strategy_classes = inspect.getmembers( + module, + lambda member: ( + inspect.isclass(member) + and member.__module__ == module.__name__ + and not member.__name__.endswith('Config') + ) + ) + + if not strategy_classes: + raise ValueError(f"No strategy class found in {mapping.module_path}") + + # Get the source code of the strategy class + strategy_class = strategy_classes[0][1] # Take the first class + source_code = inspect.getsource(strategy_class) + + return source_code + + except Exception as e: + logger.error(f"Error loading strategy code for {mapping.id}: {str(e)}") + return f"# Strategy implementation code not found for {mapping.id}" + + async def _initialize_system_context(self): + """Initialize the system prompt in slot -1.""" + system_prompt = """<|im_start|>system +You are an expert trading strategy advisor. Your task is to analyze trading strategies and suggest optimal parameter values. + +For each parameter suggestion, you must: +1. Suggest an appropriate value based on the strategy type and configuration +2. Provide a clear explanation of why this value would be appropriate +3. Consider potential risks and market conditions +4. Take into account the strategy's implementation code and logic + +Format each suggestion exactly as follows: +PARAMETER: [parameter_name] +VALUE: [suggested_value] +REASONING: [detailed explanation] +<|im_end|>""" + + try: + async with aiohttp.ClientSession() as session: + await session.post( + self.api_url, + headers={"Content-Type": "application/json"}, + json={ + "prompt": system_prompt, + "temperature": 0.9, + "top_p": 1, + "top_k": 40, + "n": 1, + "n_predict": 100, + "stop": ["<|im_end|>"] + } + ) + except Exception as e: + print(f"ERROR: Error initializing system context: {str(e)}") + raise + + async def _initialize_strategy_context( + self, + strategy_mapping: StrategyMapping, + strategy_config: Dict[str, Any], + strategy_code: str, + slot_id: int + ): + """Initialize context for a specific strategy.""" + # Convert strategy parameters to a serializable format + serializable_config = { + name: { + "name": param.name, + "group": param.group, + "type": param.type, + "prompt": param.prompt, + "default": str(param.default) if param.default is not None else None, + "required": param.required, + "min_value": str(param.min_value) if param.min_value is not None else None, + "max_value": str(param.max_value) if param.max_value is not None else None, + "is_advanced": param.is_advanced, + "display_type": param.display_type + } + for name, param in strategy_config.items() + } + + strategy_context = f"""<|im_start|>user +Trading Strategy: {strategy_mapping.display_name} +Type: {strategy_mapping.strategy_type.value} +Description: {strategy_mapping.description} + +Strategy Configuration Schema: +{json.dumps(serializable_config, indent=2)} + +Strategy Implementation: +```python +{strategy_code} +``` + +This strategy's configuration defines the parameters and their constraints, while the implementation shows how these parameters are used in the trading logic. Use both to make informed suggestions about parameter values. +<|im_end|>""" + + try: + async with aiohttp.ClientSession() as session: + await session.post( + self.api_url, + headers={"Content-Type": "application/json"}, + json={ + "prompt": strategy_context, + "temperature": 0.9, + "top_p": 1, + "top_k": 40, + "n": 1, + "n_predict": 100, + "stop": ["<|im_end|>"], + "slot_id": slot_id, + "parent_slot_id": -1, + } + ) + except Exception as e: + print(f"ERROR: Error initializing strategy context for {strategy_mapping.id}: {str(e)}") + raise + + async def get_parameter_suggestions( + self, + strategy_id: str, + strategy_config: Dict[str, Any], + provided_params: Dict[str, Any], + requested_params: Optional[List[str]] = None + ) -> List[ParameterSuggestion]: + """Get parameter suggestions from LibertAI. + + Args: + strategy_id: ID of the strategy + strategy_config: Full strategy configuration + provided_params: Parameters already provided by the user + requested_params: Optional list of specific parameters to get suggestions for + """ + print("\n=== Getting Parameter Suggestions ===") + print(f"Strategy ID: {strategy_id}") + print(f"Provided parameters: {json.dumps(provided_params, indent=2)}") + print(f"Requested parameters: {requested_params}") + + # Get strategy configuration + strategies = discover_strategies() + strategy = strategies.get(strategy_id) + if not strategy: + print(f"ERROR: No strategy found with ID {strategy_id}") + return [] + + # Identify missing required parameters and optional parameters + missing_required = [] + optional_params = [] + + # If specific parameters are requested, only consider those + params_to_check = requested_params if requested_params else strategy_config.keys() + + for param_name in params_to_check: + if param_name not in provided_params: + param_config = strategy_config.get(param_name) + if param_config: + if param_config.required: + missing_required.append(param_name) + else: + optional_params.append(param_name) + + print(f"Missing required parameters: {missing_required}") + print(f"Optional parameters: {optional_params}") + + # Get the strategy's slot ID + slot_id = self.strategy_slot_map.get(strategy_id) + if slot_id is None: + print(f"ERROR: No cached context found for strategy {strategy_id}") + return [] + + # Convert parameters to a serializable format + serializable_params = { + name: str(value) if hasattr(value, "__str__") else value + for name, value in provided_params.items() + } + + # Update the prompt to be more explicit about the format and requested parameters + optional_params_text = f"Optional Parameters That Could Be Set:\n{', '.join(optional_params) if optional_params else 'None'}" if not requested_params else "" + + request_prompt = f"""<|im_start|>user +Strategy: {strategy.mapping.display_name} +Type: {strategy.mapping.strategy_type.value} + +Currently Provided Parameters: +{json.dumps(serializable_params, indent=2)} + +{"Parameters to Suggest:" if requested_params else "Missing Required Parameters:"} +{', '.join(requested_params) if requested_params else ', '.join(missing_required) if missing_required else 'None'} + +{optional_params_text} + +Please suggest optimal values for {"the requested" if requested_params else "the missing"} parameters using exactly this format for each parameter: + +PARAMETER: [parameter_name] +VALUE: [suggested_value] +REASONING: [detailed explanation of why this value is appropriate] + +End with a summary: +SUMMARY: [overall explanation of the suggested configuration] + +Do not include code blocks or other formats. Use only the PARAMETER/VALUE/REASONING structure. +<|im_end|>""" + + try: + async with aiohttp.ClientSession() as session: + print(f"\nSending request to LibertAI API...") + print(f"Request prompt:\n{request_prompt}") + + request_payload = { + "slot_id": self.next_slot_id, + "parent_slot_id": slot_id, + "prompt": request_prompt, + "temperature": 0.9, + "top_p": 1, + "top_k": 40, + "n": 1, + "n_predict": 1500, + "stop": ["<|im_end|>"] + } + + async with session.post( + self.api_url, + headers={"Content-Type": "application/json"}, + json=request_payload + ) as response: + if response.status != 200: + print(f"ERROR: API returned status {response.status}") + response_text = await response.text() + print(f"Response body: {response_text}") + return [] + + result = await response.json() + print(f"\nReceived response from API: {json.dumps(result, indent=2)}") + return self._parse_ai_response( + {"choices": [{"message": {"content": result["content"]}}]}, + strategy_config=strategy_config, + provided_params=provided_params + ) + + except Exception as e: + print(f"ERROR: Exception during API call: {str(e)}") + return [] + + def _parse_ai_response(self, ai_response: Dict[str, Any], strategy_config: Dict[str, Any], provided_params: Dict[str, Any]) -> List[ParameterSuggestion]: + print("\n=== Parsing AI Response ===") + try: + content = ai_response["choices"][0]["message"]["content"] + print(f"Response content preview: {content[:200]}...") + + suggestions = [] + seen_params = set() + summary = None + + # Create a map of default values and provided values for comparison + default_values = { + name: str(param.default) if param.default is not None else None + for name, param in strategy_config.items() + } + + provided_values = { + name: str(value) if hasattr(value, "__str__") else str(value) + for name, value in provided_params.items() + } + + if "PARAMETER:" in content: + print("Found structured format with PARAMETER/VALUE/REASONING") + parameter_sections = content.split("PARAMETER:") + + for section in parameter_sections[1:]: + lines = section.strip().split("\n") + param_name = lines[0].strip() + + # Initialize collectors for multi-line values + value_lines = [] + reasoning_lines = [] + collecting_value = False + collecting_reasoning = False + + # Process remaining lines + for line in lines[1:]: + line = line.strip() + + if line.startswith("VALUE:"): + collecting_value = True + collecting_reasoning = False + value_lines.append(line.replace("VALUE:", "").strip()) + elif line.startswith("REASONING:"): + collecting_value = False + collecting_reasoning = True + reasoning_lines.append(line.replace("REASONING:", "").strip()) + elif line.startswith("SUMMARY:"): + collecting_value = False + collecting_reasoning = False + summary = line.replace("SUMMARY:", "").strip() + else: + # Continue collecting multi-line values + if collecting_value and line: + value_lines.append(line) + elif collecting_reasoning and line: + reasoning_lines.append(line) + + # Process collected values + if param_name and value_lines and param_name not in seen_params: + seen_params.add(param_name) + + # Join multi-line values and try to parse as JSON if it looks like a JSON structure + value = "\n".join(value_lines) + if value.strip().startswith("{") and value.strip().endswith("}"): + try: + parsed_value = json.loads(value) + value = json.dumps(parsed_value) + except json.JSONDecodeError: + pass + + # Compare with default and provided values + differs_from_default = ( + param_name in default_values and + default_values[param_name] is not None and + value != default_values[param_name] + ) + differs_from_provided = ( + param_name in provided_values and + value != provided_values[param_name] + ) + + suggestions.append(ParameterSuggestion( + parameter_name=param_name, + suggested_value=value, + reasoning="\n".join(reasoning_lines) if reasoning_lines else "No reasoning provided", + differs_from_default=differs_from_default, + differs_from_provided=differs_from_provided + )) + + if summary: + suggestions.append(ParameterSuggestion( + parameter_name="summary", + suggested_value=summary, + reasoning="Summary of the suggested configuration", + differs_from_default=False, + differs_from_provided=False + )) + + print(f"\nTotal suggestions parsed: {len(suggestions)}") + for s in suggestions: + print(f"- {s.parameter_name}: {s.suggested_value}") + if s.differs_from_default: + print(f" (differs from default: {s.differs_from_default})") + if s.differs_from_provided: + print(f" (differs from provided: {s.differs_from_provided})") + + return suggestions + + except Exception as e: + print(f"ERROR: Failed to parse AI response: {str(e)}") + print(f"Raw response: {json.dumps(ai_response, indent=2)}") + return [] \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..9d55c93 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Tests package initialization. +""" \ No newline at end of file diff --git a/tests/test_libert_ai_service.py b/tests/test_libert_ai_service.py new file mode 100644 index 0000000..b8c38c6 --- /dev/null +++ b/tests/test_libert_ai_service.py @@ -0,0 +1,164 @@ +import pytest +from services.libert_ai_service import LibertAIService +from routers.strategies_models import ( + ParameterSuggestion, + discover_strategies, +) +from typing import Any +from dataclasses import dataclass + +@pytest.fixture +def libert_ai_service(): + service = LibertAIService() + return service + +@pytest.fixture +def strategy_configs(): + """Load all available strategies""" + return discover_strategies() + +@pytest.fixture +def bollinger_strategy(strategy_configs): + """Get the Bollinger strategy configuration""" + return strategy_configs["bollinger_v1"] + +@pytest.mark.asyncio +async def test_initialize_system_context(libert_ai_service): + """Test system context initialization""" + await libert_ai_service._initialize_system_context() + +@pytest.mark.asyncio +async def test_initialize_strategy_context(libert_ai_service, bollinger_strategy): + """Test strategy context initialization""" + with open(f"bots/{bollinger_strategy.mapping.module_path.split('bots.')[-1].replace('.', '/')}.py", "r") as f: + strategy_code = f.read() + + await libert_ai_service._initialize_strategy_context( + strategy_mapping=bollinger_strategy.mapping, + strategy_config=bollinger_strategy.parameters, + strategy_code=strategy_code, + slot_id=0 + ) + +@pytest.mark.asyncio +async def test_get_parameter_suggestions(libert_ai_service, bollinger_strategy): + """Test parameter suggestion generation""" + + libert_ai_service.strategy_slot_map["bollinger_v1"] = 0 + + suggestions = await libert_ai_service.get_parameter_suggestions( + strategy_id="bollinger_v1", + strategy_config=bollinger_strategy.parameters, + provided_params={"bb_std": 2.0} + ) + + # Verify suggestions are part of the bollinger_strategy.parameters + for suggestion in suggestions: + assert suggestion.parameter_name in bollinger_strategy.parameters or suggestion.parameter_name == "summary" + +@pytest.mark.asyncio +async def test_parse_ai_response(libert_ai_service): + """Test AI response parsing""" + ai_response = { + "choices": [{ + "message": { + "content": """ +PARAMETER: bb_length +VALUE: 100 +REASONING: Standard length for Bollinger Bands calculation provides reliable signals while filtering out noise. + +PARAMETER: bb_long_threshold +VALUE: 0.2 +REASONING: Enter long positions when price is 20% below the middle band, indicating oversold conditions. + +SUMMARY: These parameters are optimized for mean reversion trading using Bollinger Bands. +""" + } + }] + } + + # Mock strategy config with proper parameter objects + @dataclass + class MockParameter: + name: str + default: Any + required: bool = True + type: str = "float" + + strategy_config = { + "bb_length": MockParameter( + name="BB Length", + default=20, + type="int" + ), + "bb_long_threshold": MockParameter( + name="BB Long Threshold", + default=0.1, + type="float" + ) + } + + provided_params = { + "bb_length": 20 + } + + suggestions = libert_ai_service._parse_ai_response( + ai_response, + strategy_config=strategy_config, + provided_params=provided_params + ) + + assert len(suggestions) == 3 # 2 parameters + summary + assert all(isinstance(s, ParameterSuggestion) for s in suggestions) + assert suggestions[0].parameter_name == "bb_length" + assert suggestions[0].suggested_value == "100" + assert suggestions[0].differs_from_default is True # 100 vs default 20 + assert suggestions[0].differs_from_provided is True # 100 vs provided 20 + assert suggestions[1].parameter_name == "bb_long_threshold" + assert suggestions[1].suggested_value == "0.2" + assert suggestions[1].differs_from_default is True # 0.2 vs default 0.1 + assert suggestions[1].differs_from_provided is False # Not provided + assert suggestions[2].parameter_name == "summary" + assert suggestions[2].suggested_value == "These parameters are optimized for mean reversion trading using Bollinger Bands." + +@pytest.mark.asyncio +async def test_parse_ai_response_handles_invalid_format(libert_ai_service): + """Test handling of invalid AI response format""" + invalid_response = { + "choices": [{ + "message": { + "content": "Invalid format response" + } + }] + } + + # Mock empty config with proper parameter objects + strategy_config = {} + provided_params = {} + + suggestions = libert_ai_service._parse_ai_response( + invalid_response, + strategy_config=strategy_config, + provided_params=provided_params + ) + assert suggestions == [] + +@pytest.mark.asyncio +async def test_get_specific_parameter_suggestions(libert_ai_service, bollinger_strategy): + """Test getting suggestions for specific parameters""" + + libert_ai_service.strategy_slot_map["bollinger_v1"] = 0 + + # Request suggestions for specific parameters + requested_params = ["bb_length", "bb_long_threshold"] + suggestions = await libert_ai_service.get_parameter_suggestions( + strategy_id="bollinger_v1", + strategy_config=bollinger_strategy.parameters, + provided_params={"bb_std": 2.0}, + requested_params=requested_params + ) + + # Verify we only got suggestions for the requested parameters (plus summary) + assert len(suggestions) == 3 # 2 requested parameters + summary + suggestion_params = {s.parameter_name for s in suggestions if s.parameter_name != "summary"} + assert suggestion_params == set(requested_params) \ No newline at end of file diff --git a/tests/test_strategies.py b/tests/test_strategies.py new file mode 100644 index 0000000..ca0e8cc --- /dev/null +++ b/tests/test_strategies.py @@ -0,0 +1,224 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from decimal import Decimal +from typing import Dict, Any +from pydantic import BaseModel, Field +from hummingbot.strategy_v2.controllers import ControllerConfigBase + +from routers.strategies_models import ( + StrategyType, + StrategyMapping, + StrategyParameter, + StrategyConfig, + discover_strategies, + generate_strategy_mapping, + convert_to_strategy_parameter, + infer_strategy_type +) + +# Mock strategy config class for testing +class MockStrategyConfig(ControllerConfigBase): + """Test strategy for unit testing""" + controller_name = "test_strategy_v1" + + stop_loss: Decimal = Field( + default=Decimal("0.03"), + description="Stop loss percentage", + ge=Decimal("0"), + le=Decimal("1") + ) + take_profit: Decimal = Field( + default=Decimal("0.02"), + description="Take profit percentage", + ge=Decimal("0"), + le=Decimal("1") + ) + time_limit: int = Field( + default=2700, + description="Time limit in seconds", + gt=0 + ) + leverage: int = Field( + default=20, + description="Leverage multiplier", + gt=0 + ) + trading_pair: str = Field( + default="BTC-USDT", + description="Trading pair to use" + ) + +# Test data +MOCK_MODULE_PATH = "bots.controllers.directional_trading.test_strategy_v1" + +@pytest.fixture +def mock_strategy_config(): + return MockStrategyConfig + +@pytest.fixture(autouse=True) +def mock_importlib(): + with patch("importlib.import_module") as mock: + mock.return_value = MagicMock( + __name__="test_module", + MockStrategyConfig=MockStrategyConfig + ) + yield mock + +@pytest.fixture(autouse=True) +def mock_os_walk(): + with patch("os.walk") as mock: + mock.return_value = [ + ("bots/controllers/directional_trading", [], ["test_strategy_v1.py"]), + ] + yield mock + +@pytest.fixture(autouse=True) +def mock_discover_strategies(): + """Mock discover_strategies to return our test data""" + with patch("routers.strategies_models.discover_strategies", autospec=True) as mock: + mock.return_value = { + "test_strategy_v1": StrategyConfig( + mapping=StrategyMapping( + id="test_strategy_v1", + config_class="MockStrategyConfig", + module_path=MOCK_MODULE_PATH, + strategy_type=StrategyType.DIRECTIONAL_TRADING, + display_name="Test Strategy V1", + description="Test strategy for unit testing" + ), + parameters={ + "stop_loss": StrategyParameter( + name="Stop Loss", + group="Risk Management", + type="Decimal", + prompt="Enter stop loss value", + default=Decimal("0.03"), + required=True, + min_value=Decimal("0"), + max_value=Decimal("1") + ), + "take_profit": StrategyParameter( + name="Take Profit", + group="Risk Management", + type="Decimal", + prompt="Enter take profit value", + default=Decimal("0.02"), + required=True, + min_value=Decimal("0"), + max_value=Decimal("1") + ), + "time_limit": StrategyParameter( + name="Time Limit", + group="General Settings", + type="int", + prompt="Enter time limit in seconds", + default=2700, + required=True, + min_value=0 + ), + "leverage": StrategyParameter( + name="Leverage", + group="Risk Management", + type="int", + prompt="Enter leverage multiplier", + default=20, + required=True, + min_value=1, + is_advanced=True + ), + "trading_pair": StrategyParameter( + name="Trading Pair", + group="General Settings", + type="str", + prompt="Enter trading pair", + default="BTC-USDT", + required=True, + is_trading_pair=True + ) + } + ) + } + yield mock + +def test_infer_strategy_type(): + """Test strategy type inference from module path""" + assert infer_strategy_type("bots.controllers.directional_trading.test", None) == StrategyType.DIRECTIONAL_TRADING + assert infer_strategy_type("bots.controllers.market_making.test", None) == StrategyType.MARKET_MAKING + assert infer_strategy_type("bots.controllers.generic.test", None) == StrategyType.GENERIC + +def test_generate_strategy_mapping(): + """Test strategy mapping generation""" + mapping = generate_strategy_mapping(MOCK_MODULE_PATH, MockStrategyConfig) + + assert mapping.id == "test_strategy_v1" + assert mapping.config_class == "MockStrategyConfig" + assert mapping.module_path == MOCK_MODULE_PATH + assert mapping.strategy_type == StrategyType.DIRECTIONAL_TRADING + assert mapping.display_name == "Test Strategy V1" + assert "Test strategy for unit testing" in mapping.description + +def test_convert_to_strategy_parameter(): + """Test parameter conversion from config field""" + # Get a field from the mock config + field = MockStrategyConfig.__fields__["stop_loss"] + param = convert_to_strategy_parameter("stop_loss", field) + + assert param.name == "Stop Loss" + assert param.group == "Risk Management" + assert param.type == "ConstrainedDecimalValue" # We want the base type, not the constrained type + assert param.default == Decimal("0.03") + assert param.required is True + assert param.min_value == Decimal("0") + assert param.max_value == Decimal("1") + assert param.display_type == "slider" + +@pytest.mark.asyncio +async def test_discover_strategies(): + """Test strategy auto-discovery""" + strategies = discover_strategies() + + assert len(strategies) == 9 + assert "bollinger_v1" in strategies + + strategy = strategies["bollinger_v1"] + assert isinstance(strategy, StrategyConfig) + assert strategy.mapping.id == "bollinger_v1" + + # Check some parameters + assert "stop_loss" in strategy.parameters + assert "take_profit" in strategy.parameters + assert strategy.parameters["leverage"].is_advanced is True + assert strategy.parameters["trading_pair"].is_trading_pair is True + + +def test_parameter_validation(): + """Test parameter validation in strategy config""" + # Test required parameters + with pytest.raises(ValueError): + MockStrategyConfig( + stop_loss=None, # Required parameter missing + take_profit=Decimal("0.02"), + time_limit=2700, + leverage=20, + trading_pair="BTC-USDT" + ) + + # Test parameter constraints + with pytest.raises(ValueError): + MockStrategyConfig( + stop_loss=Decimal("-0.03"), # Negative value not allowed + take_profit=Decimal("0.02"), + time_limit=2700, + leverage=20, + trading_pair="BTC-USDT" + ) + +def test_strategy_type_enum(): + """Test StrategyType enum values""" + assert StrategyType.DIRECTIONAL_TRADING == "directional_trading" + assert StrategyType.MARKET_MAKING == "market_making" + assert StrategyType.GENERIC == "generic" + + # Test that invalid types are not allowed + with pytest.raises(ValueError): + StrategyType("invalid_type") \ No newline at end of file From b3057cacb1c7ddcc9719f58635ceeee7f997a40d Mon Sep 17 00:00:00 2001 From: MHHukiewitz Date: Mon, 16 Dec 2024 16:55:21 +0100 Subject: [PATCH 2/5] Add fallbacks --- routers/strategies.py | 11 +++++++++++ services/libert_ai_service.py | 1 - 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/routers/strategies.py b/routers/strategies.py index af93e5b..3008f52 100644 --- a/routers/strategies.py +++ b/routers/strategies.py @@ -23,11 +23,17 @@ async def lifespan(app: FastAPI): # Initialize contexts on startup try: # Load strategies using auto-discovery + print("Initializing LibertAI contexts...") strategies = discover_strategies() await libert_ai_service.initialize_contexts(strategies) + print(f"Successfully initialized contexts for {len(strategies)} strategies") except Exception as e: print(f"Error initializing LibertAI contexts: {str(e)}") + # Re-raise the exception to prevent app startup if context initialization fails + raise yield + # Cleanup on shutdown if needed + print("Cleaning up LibertAI contexts...") # Create the FastAPI app with the lifespan handler app = FastAPI(lifespan=lifespan) @@ -60,6 +66,11 @@ async def suggest_parameters(request: ParameterSuggestionRequest) -> ParameterSu strategy = strategies[request.strategy_id] + # Ensure context is initialized for this strategy + if request.strategy_id not in libert_ai_service.strategy_slot_map: + print(f"Re-initializing context for strategy {request.strategy_id}") + await libert_ai_service.initialize_contexts({request.strategy_id: strategy}) + try: # Get suggestions from LibertAI suggestions = await libert_ai_service.get_parameter_suggestions( diff --git a/services/libert_ai_service.py b/services/libert_ai_service.py index 127a2a3..cbe2a8d 100644 --- a/services/libert_ai_service.py +++ b/services/libert_ai_service.py @@ -1,4 +1,3 @@ -import os import json import logging import aiohttp From 77e1d4602312e0ed003671fd86a41849a089f3e1 Mon Sep 17 00:00:00 2001 From: MHHukiewitz Date: Mon, 16 Dec 2024 18:26:28 +0100 Subject: [PATCH 3/5] Fix tests for new models --- tests/test_strategies.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/test_strategies.py b/tests/test_strategies.py index ca0e8cc..91f3b16 100644 --- a/tests/test_strategies.py +++ b/tests/test_strategies.py @@ -88,7 +88,9 @@ def mock_discover_strategies(): ), parameters={ "stop_loss": StrategyParameter( - name="Stop Loss", + name="stop_loss", + pretty_name="Stop Loss", + description="Stop loss percentage", group="Risk Management", type="Decimal", prompt="Enter stop loss value", @@ -98,7 +100,9 @@ def mock_discover_strategies(): max_value=Decimal("1") ), "take_profit": StrategyParameter( - name="Take Profit", + name="take_profit", + pretty_name="Take Profit", + description="Take profit percentage", group="Risk Management", type="Decimal", prompt="Enter take profit value", @@ -108,7 +112,9 @@ def mock_discover_strategies(): max_value=Decimal("1") ), "time_limit": StrategyParameter( - name="Time Limit", + name="time_limit", + pretty_name="Time Limit", + description="Time limit in seconds", group="General Settings", type="int", prompt="Enter time limit in seconds", @@ -117,7 +123,9 @@ def mock_discover_strategies(): min_value=0 ), "leverage": StrategyParameter( - name="Leverage", + name="leverage", + pretty_name="Leverage", + description="Leverage multiplier", group="Risk Management", type="int", prompt="Enter leverage multiplier", @@ -127,7 +135,9 @@ def mock_discover_strategies(): is_advanced=True ), "trading_pair": StrategyParameter( - name="Trading Pair", + name="trading_pair", + pretty_name="Trading Pair", + description="Trading pair to use", group="General Settings", type="str", prompt="Enter trading pair", @@ -163,7 +173,7 @@ def test_convert_to_strategy_parameter(): field = MockStrategyConfig.__fields__["stop_loss"] param = convert_to_strategy_parameter("stop_loss", field) - assert param.name == "Stop Loss" + assert param.pretty_name == "Stop Loss" assert param.group == "Risk Management" assert param.type == "ConstrainedDecimalValue" # We want the base type, not the constrained type assert param.default == Decimal("0.03") From 77dbf0ad929a732c82fb06f7fe5cc5e8290e75a0 Mon Sep 17 00:00:00 2001 From: MHHukiewitz Date: Mon, 16 Dec 2024 18:43:47 +0100 Subject: [PATCH 4/5] Reorganize strategy models and API types --- routers/strategies.py | 8 +- routers/strategies_models.py | 351 ++++++++++++++++++++++------------- 2 files changed, 227 insertions(+), 132 deletions(-) diff --git a/routers/strategies.py b/routers/strategies.py index 3008f52..9366d77 100644 --- a/routers/strategies.py +++ b/routers/strategies.py @@ -1,8 +1,5 @@ -import json -import os -from typing import Dict, List +from typing import Dict from fastapi import APIRouter, HTTPException -from pydantic import BaseModel from fastapi import FastAPI from contextlib import asynccontextmanager @@ -11,8 +8,7 @@ ParameterSuggestionRequest, ParameterSuggestionResponse, StrategyConfig, - discover_strategies, - get_strategy_mapping + discover_strategies ) # Create a libert_ai_service instance diff --git a/routers/strategies_models.py b/routers/strategies_models.py index 27b16cc..070ad7a 100644 --- a/routers/strategies_models.py +++ b/routers/strategies_models.py @@ -22,6 +22,81 @@ class StrategyType(str, Enum): MARKET_MAKING = "market_making" GENERIC = "generic" +class DisplayType(str, Enum): + INPUT = "input" + SLIDER = "slider" + DROPDOWN = "dropdown" + TOGGLE = "toggle" + DATE = "date" + +class ParameterType(str, Enum): + PERCENTAGE = "percentage" + PRICE = "price" + TIMESPAN = "timespan" + CONNECTOR = "connector" + TRADING_PAIR = "trading_pair" + INTEGER = "integer" + DECIMAL = "decimal" + STRING = "string" + BOOLEAN = "boolean" + +class ParameterGroup(str, Enum): + GENERAL = "General Settings" + RISK = "Risk Management" + BUY = "Buy Order Settings" + SELL = "Sell Order Settings" + DCA = "DCA Settings" + INDICATORS = "Indicator Settings" + PROFITABILITY = "Profitability Settings" + EXECUTION = "Execution Settings" + ARBITRAGE = "Arbitrage Settings" + OTHER = "Other" + +class StrategyError(Exception): + """Base class for strategy-related errors""" + +class StrategyNotFoundError(StrategyError): + """Raised when a strategy cannot be found""" + +class StrategyValidationError(StrategyError): + """Raised when strategy parameters are invalid""" + +class ParameterConstraints(BaseModel): + min_value: Optional[Union[int, float, Decimal]] = None + max_value: Optional[Union[int, float, Decimal]] = None + valid_values: Optional[List[Any]] = None + +class StrategyParameter(BaseModel): + # Core attributes + name: str + type: str + required: bool + default: Optional[Any] + + # Display attributes + display_name: str + description: str + group: ParameterGroup + is_advanced: bool + + # Validation attributes + constraints: Optional[ParameterConstraints] = None + + # UI attributes + display_type: DisplayType = DisplayType.INPUT + + # Type flags (for backward compatibility and specific handling) + parameter_type: Optional[ParameterType] = None + +class Strategy(BaseModel): + id: str + name: str + description: str + type: StrategyType + module_path: str + config_class: str + parameters: Dict[str, StrategyParameter] + class StrategyMapping(BaseModel): """Maps a strategy ID to its implementation details""" id: str # e.g., "supertrend_v1" @@ -31,27 +106,6 @@ class StrategyMapping(BaseModel): display_name: str # e.g., "Supertrend V1" description: str = "" -class StrategyParameter(BaseModel): - name: str - group: str - is_advanced: bool = False - pretty_name: str - description: str - type: str - prompt: str - default: Optional[Any] - required: bool - min_value: Optional[Union[int, float, Decimal]] = None - max_value: Optional[Union[int, float, Decimal]] = None - valid_values: Optional[List[Any]] = None - is_percentage: bool = False - is_price: bool = False - is_timespan: bool = False - is_connector: bool = False - is_trading_pair: bool = False - is_integer: bool = False - display_type: str = Field(default="input", description="Can be 'input', 'slider', 'dropdown', 'toggle', or 'date'") - class StrategyConfig(BaseModel): """Complete strategy configuration including metadata and parameters""" mapping: StrategyMapping @@ -124,106 +178,123 @@ def get_strategy_display_info() -> Dict[str, Dict[str, str]]: } } -def convert_to_strategy_parameter(name: str, field: ModelField) -> StrategyParameter: - param = StrategyParameter( - name=name, - type=str(field.type_.__name__), - prompt=field.description if hasattr(field, 'description') else "", - default=field.default, - required=field.required or field.default is not None, - is_advanced=is_advanced_parameter(name), - group=determine_parameter_group(name), - pretty_name=name.replace('_', ' ').title(), - description="", - ) +class StrategyRegistry: + """Central registry for all trading strategies""" + + _cache: Dict[str, Strategy] = {} + + @classmethod + def _ensure_cache_loaded(cls): + if not cls._cache: + cls._cache = discover_strategies() + + @classmethod + def get_all_strategies(cls) -> Dict[str, Strategy]: + """Get all available strategies with their configurations""" + cls._ensure_cache_loaded() + return cls._cache + + @classmethod + def get_strategy(cls, strategy_id: str) -> Optional[Strategy]: + """Get a specific strategy by ID""" + cls._ensure_cache_loaded() + strategy = cls._cache.get(strategy_id) + if not strategy: + raise StrategyNotFoundError(f"Strategy '{strategy_id}' not found") + return strategy + + @classmethod + def get_strategies_by_type(cls, strategy_type: StrategyType) -> List[Strategy]: + """Get all strategies of a specific type""" + cls._ensure_cache_loaded() + return [s for s in cls._cache.values() if s.type == strategy_type] - # Get strategy display info - strategy_info = get_strategy_display_info() - - # Try to find matching strategy info - for strategy_name, info in strategy_info.items(): - if strategy_name in name.lower(): - param.pretty_name = info["pretty_name"] - param.description = info["description"] - break - - # structure of field - client_data = field.field_info.extra.get('client_data') - if client_data is not None and param.prompt == "": - desc = client_data.prompt(None) if callable(client_data.prompt) else client_data.prompt - if desc is not None: - param.prompt = desc - if not param.required: - param.required = client_data.prompt_on_new if hasattr(client_data, 'prompt_on_new') else param.required - param.display_type = "input" +def convert_to_strategy_parameter(name: str, field: ModelField) -> StrategyParameter: + """Convert a model field to a strategy parameter""" + constraints = ParameterConstraints() - # Check for gt (greater than) and lt (less than) in field definition + # Handle constraints if hasattr(field.field_info, 'ge'): - param.min_value = field.field_info.ge + constraints.min_value = field.field_info.ge elif hasattr(field.field_info, 'gt'): - param.min_value = field.field_info.gt + (1 if isinstance(field.field_info.gt, int) else Decimal('0')) + constraints.min_value = field.field_info.gt + (1 if isinstance(field.field_info.gt, int) else Decimal('0')) if hasattr(field.field_info, 'le'): - param.max_value = field.field_info.le + constraints.max_value = field.field_info.le elif hasattr(field.field_info, 'lt'): - param.max_value = field.field_info.lt - (1 if isinstance(field.field_info.lt, int) else Decimal('0')) + constraints.max_value = field.field_info.lt - (1 if isinstance(field.field_info.lt, int) else Decimal('0')) + # Determine parameter type + param_type = None + if "connector" in name.lower(): + param_type = ParameterType.CONNECTOR + elif "trading_pair" in name.lower(): + param_type = ParameterType.TRADING_PAIR + elif any(word in name.lower() for word in ["percentage", "percent", "ratio", "pct"]): + param_type = ParameterType.PERCENTAGE + elif "price" in name.lower(): + param_type = ParameterType.PRICE + elif any(word in name.lower() for word in ["time", "interval", "duration"]): + param_type = ParameterType.TIMESPAN + elif str(field.type_.__name__).lower() == "int": + param_type = ParameterType.INTEGER + elif str(field.type_.__name__).lower() == "decimal": + param_type = ParameterType.DECIMAL + elif str(field.type_.__name__).lower() == "bool": + param_type = ParameterType.BOOLEAN + else: + param_type = ParameterType.STRING + + # Determine display type + display_type = DisplayType.INPUT + if constraints.min_value is not None and constraints.max_value is not None: + display_type = DisplayType.SLIDER + elif param_type == ParameterType.BOOLEAN: + display_type = DisplayType.TOGGLE + elif constraints.valid_values: + display_type = DisplayType.DROPDOWN + + # Get group + group = determine_parameter_group(name) - # Set display_type to "slider" only if both min and max values are present - if param.min_value is not None and param.max_value is not None: - param.display_type = "slider" - elif param.type == "bool": - param.display_type = "toggle" + return StrategyParameter( + name=name, + type=str(field.type_.__name__), + required=field.required or field.default is not None, + default=field.default, + display_name=name.replace('_', ' ').title(), + description=field.description if hasattr(field, 'description') else "", + group=group, + is_advanced=is_advanced_parameter(name), + constraints=constraints, + display_type=display_type, + parameter_type=param_type + ) - if "connector" in name.lower(): - param.is_connector = True - if "trading_pair" in name.lower(): - param.is_trading_pair = True - if any(word in name.lower() for word in ["percentage", "percent", "ratio", "pct"]): - param.is_percentage = True - if "price" in name.lower(): - param.is_price = True - if param.min_value is None: - param.min_value = Decimal(0) - if "amount" in name.lower(): - param.min_value = Decimal(0) - if any(word in name.lower() for word in ["time", "interval", "duration"]): - param.is_timespan = True - param.min_value = 0 - if param.type == "int": - param.is_integer = True - if any(word in name.lower() for word in ["executors", "workers"]): - param.display_type = "slider" - param.min_value = 1 - try: - if issubclass(field.type_, Enum): - param.valid_values = [item.value for item in field.type_] - param.display_type = "dropdown" - except: - pass - return param - -def determine_parameter_group(name: str) -> str: - if any(word in name.lower() for word in ["controller_name", "candles", "interval"]): - return "General Settings" - elif any(word in name.lower() for word in ["stop_loss", "trailing_stop", "take_profit", "activation_bounds", "leverage", "triple_barrier"]): - return "Risk Management" - elif "buy" in name.lower(): - return "Buy Order Settings" - elif "sell" in name.lower(): - return "Sell Order Settings" - elif "dca" in name.lower(): - return "DCA Settings" - elif any(word in name.lower() for word in ["bb", "macd", "natr", "length", "multiplier"]): - return "Indicator Settings" - elif any(word in name.lower() for word in ["profitability", "position_size"]): - return "Profitability Settings" - elif any(word in name.lower() for word in ["time_limit", "executor", "imbalance"]): - return "Execution Settings" - elif any(word in name.lower() for word in ["spot", "perp"]): - return "Arbitrage Settings" +def determine_parameter_group(name: str) -> ParameterGroup: + """Determine the parameter group based on the parameter name""" + name_lower = name.lower() + + if any(word in name_lower for word in ["controller_name", "candles", "interval"]): + return ParameterGroup.GENERAL + elif any(word in name_lower for word in ["stop_loss", "trailing_stop", "take_profit", "activation_bounds", "leverage", "triple_barrier"]): + return ParameterGroup.RISK + elif "buy" in name_lower: + return ParameterGroup.BUY + elif "sell" in name_lower: + return ParameterGroup.SELL + elif "dca" in name_lower: + return ParameterGroup.DCA + elif any(word in name_lower for word in ["bb", "macd", "natr", "length", "multiplier"]): + return ParameterGroup.INDICATORS + elif any(word in name_lower for word in ["profitability", "position_size"]): + return ParameterGroup.PROFITABILITY + elif any(word in name_lower for word in ["time_limit", "executor", "imbalance"]): + return ParameterGroup.EXECUTION + elif any(word in name_lower for word in ["spot", "perp"]): + return ParameterGroup.ARBITRAGE else: - return "Other" + return ParameterGroup.OTHER def snake_case_to_real_name(snake_case: str) -> str: return " ".join([word.capitalize() for word in snake_case.split("_")]) @@ -261,9 +332,9 @@ def generate_strategy_mapping(module_path: str, config_class: Any) -> StrategyMa ) @functools.lru_cache(maxsize=1) -def discover_strategies() -> Dict[str, StrategyConfig]: +def discover_strategies() -> Dict[str, Strategy]: """Discover and load all available strategies""" - strategy_configs = {} + strategies = {} controllers_dir = "bots/controllers" for root, dirs, files in os.walk(controllers_dir): @@ -285,18 +356,29 @@ def discover_strategies() -> Dict[str, StrategyConfig]: ): assert isinstance(obj, ModelMetaclass) - # Generate mapping - mapping = generate_strategy_mapping(module_path, obj) - + # Extract strategy ID from module path + strategy_id = module_path.split(".")[-1] + + # Get strategy type + strategy_type = infer_strategy_type(module_path, obj) + + # Get display info + display_info = get_strategy_display_info().get(strategy_id, {}) + # Convert parameters parameters = {} for field_name, field in obj.__fields__.items(): param = convert_to_strategy_parameter(field_name, field) parameters[field_name] = param - # Create complete strategy config - strategy_configs[mapping.id] = StrategyConfig( - mapping=mapping, + # Create strategy + strategies[strategy_id] = Strategy( + id=strategy_id, + name=display_info.get("pretty_name", " ".join(word.capitalize() for word in strategy_id.split("_"))), + description=display_info.get("description", obj.__doc__ or ""), + type=strategy_type, + module_path=module_path, + config_class=obj.__name__, parameters=parameters ) @@ -307,20 +389,37 @@ def discover_strategies() -> Dict[str, StrategyConfig]: import traceback traceback.print_exc() - return strategy_configs + return strategies def get_strategy_mapping(strategy_id: str) -> Optional[StrategyMapping]: - """Get strategy mapping by ID""" - strategies = discover_strategies() - strategy = strategies.get(strategy_id) - return strategy.mapping if strategy else None + """ + DEPRECATED: Use StrategyRegistry.get_strategy() instead + Get strategy mapping by ID + """ + logger.warning("get_strategy_mapping is deprecated. Use StrategyRegistry.get_strategy() instead") + strategy = StrategyRegistry.get_strategy(strategy_id) + if not strategy: + return None + return StrategyMapping( + id=strategy.id, + config_class=strategy.config_class, + module_path=strategy.module_path, + strategy_type=strategy.type, + display_name=strategy.name, + description=strategy.description + ) def get_strategy_module_path(strategy_id: str) -> Optional[str]: - """Get the module path for a strategy""" - mapping = get_strategy_mapping(strategy_id) - return mapping.module_path if mapping else None + """ + DEPRECATED: Use StrategyRegistry.get_strategy().module_path instead + Get the module path for a strategy + """ + logger.warning("get_strategy_module_path is deprecated. Use StrategyRegistry.get_strategy().module_path instead") + strategy = StrategyRegistry.get_strategy(strategy_id) + return strategy.module_path if strategy else None def is_advanced_parameter(name: str) -> bool: + """Determine if a parameter should be considered advanced""" advanced_keywords = [ "activation_bounds", "triple_barrier", "leverage", "dca", "macd", "natr", "multiplier", "imbalance", "executor", "perp", "arbitrage" From 6034f67bf80e5a7bed3a01e1a76d8b1329fb57a1 Mon Sep 17 00:00:00 2001 From: MHHukiewitz Date: Mon, 16 Dec 2024 18:50:32 +0100 Subject: [PATCH 5/5] Improve error handling and API messages; add validation of strategy parameters --- routers/backtest.py | 104 +++++++++++++----- routers/bots.py | 238 +++++++++++++++++++++++++++++------------- routers/strategies.py | 70 ++++++++++--- 3 files changed, 295 insertions(+), 117 deletions(-) diff --git a/routers/backtest.py b/routers/backtest.py index ec78fd8..1b080b6 100644 --- a/routers/backtest.py +++ b/routers/backtest.py @@ -8,6 +8,7 @@ from config import CONTROLLERS_MODULE, CONTROLLERS_PATH from routers.backtest_models import BacktestResponse, BacktestResults, BacktestingConfig, ExecutorInfo, ProcessedData +from routers.strategies_models import StrategyError router = APIRouter(tags=["Market Backtesting"]) candles_factory = CandlesFactory() @@ -19,37 +20,84 @@ "market_making": market_making_backtesting } +class BacktestError(StrategyError): + """Base class for backtesting-related errors""" + +class BacktestConfigError(BacktestError): + """Raised when there's an error in the backtesting configuration""" + +class BacktestEngineError(BacktestError): + """Raised when there's an error during backtesting execution""" + @router.post("/backtest", response_model=BacktestResponse) async def run_backtesting(backtesting_config: BacktestingConfig) -> BacktestResponse: try: - if isinstance(backtesting_config.config, str): - controller_config = BacktestingEngineBase.get_controller_config_instance_from_yml( - config_path=backtesting_config.config, - controllers_conf_dir_path=CONTROLLERS_PATH, - controllers_module=CONTROLLERS_MODULE - ) - else: - controller_config = BacktestingEngineBase.get_controller_config_instance_from_dict( - config_data=backtesting_config.config, - controllers_module=CONTROLLERS_MODULE - ) + # Load and validate controller config + try: + if isinstance(backtesting_config.config, str): + controller_config = BacktestingEngineBase.get_controller_config_instance_from_yml( + config_path=backtesting_config.config, + controllers_conf_dir_path=CONTROLLERS_PATH, + controllers_module=CONTROLLERS_MODULE + ) + else: + controller_config = BacktestingEngineBase.get_controller_config_instance_from_dict( + config_data=backtesting_config.config, + controllers_module=CONTROLLERS_MODULE + ) + except Exception as e: + raise BacktestConfigError(f"Invalid controller configuration: {str(e)}") + + # Get and validate backtesting engine backtesting_engine = BACKTESTING_ENGINES.get(controller_config.controller_type) if not backtesting_engine: - raise ValueError(f"Backtesting engine for controller type {controller_config.controller_type} not found.") - backtesting_results = await backtesting_engine.run_backtesting( - controller_config=controller_config, trade_cost=backtesting_config.trade_cost, - start=int(backtesting_config.start_time), end=int(backtesting_config.end_time), - backtesting_resolution=backtesting_config.backtesting_resolution) - - processed_data = backtesting_results["processed_data"]["features"].fillna(0).to_dict() - executors_info = [ExecutorInfo(**e.to_dict()) for e in backtesting_results["executors"]] - results = backtesting_results["results"] - results["sharpe_ratio"] = results["sharpe_ratio"] if results["sharpe_ratio"] is not None else 0 - - return BacktestResponse( - executors=executors_info, - processed_data=ProcessedData(features=processed_data), - results=BacktestResults(**results) - ) + raise BacktestConfigError( + f"Backtesting engine for controller type {controller_config.controller_type} not found. " + f"Available types: {list(BACKTESTING_ENGINES.keys())}" + ) + + # Validate time range + if backtesting_config.end_time <= backtesting_config.start_time: + raise BacktestConfigError( + f"Invalid time range: end_time ({backtesting_config.end_time}) must be greater than " + f"start_time ({backtesting_config.start_time})" + ) + + try: + # Run backtesting + backtesting_results = await backtesting_engine.run_backtesting( + controller_config=controller_config, + trade_cost=backtesting_config.trade_cost, + start=int(backtesting_config.start_time), + end=int(backtesting_config.end_time), + backtesting_resolution=backtesting_config.backtesting_resolution + ) + except Exception as e: + raise BacktestEngineError(f"Error during backtesting execution: {str(e)}") + + try: + # Process results + processed_data = backtesting_results["processed_data"]["features"].fillna(0).to_dict() + executors_info = [ExecutorInfo(**e.to_dict()) for e in backtesting_results["executors"]] + results = backtesting_results["results"] + results["sharpe_ratio"] = results["sharpe_ratio"] if results["sharpe_ratio"] is not None else 0 + + return BacktestResponse( + executors=executors_info, + processed_data=ProcessedData(features=processed_data), + results=BacktestResults(**results) + ) + except Exception as e: + raise BacktestError(f"Error processing backtesting results: {str(e)}") + + except BacktestConfigError as e: + raise HTTPException(status_code=400, detail=str(e)) + except BacktestEngineError as e: + raise HTTPException(status_code=500, detail=str(e)) + except BacktestError as e: + raise HTTPException(status_code=500, detail=str(e)) except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) \ No newline at end of file + raise HTTPException( + status_code=500, + detail=f"Unexpected error during backtesting: {str(e)}" + ) \ No newline at end of file diff --git a/routers/bots.py b/routers/bots.py index ed653ae..f3fbf63 100644 --- a/routers/bots.py +++ b/routers/bots.py @@ -5,103 +5,191 @@ from fastapi_walletauth import JWTWalletAuthDep from utils.models import HummingbotInstanceConfig, StartStrategyRequest, InstanceResponse from hummingbot.core.gateway.gateway_http_client import GatewayHttpClient +from routers.strategies_models import ( + StrategyError, + StrategyRegistry, + StrategyNotFoundError, + StrategyValidationError +) router = APIRouter(tags=["Bot Management"]) accounts_service = AccountsService() docker_manager = DockerManager() gateway_client = GatewayHttpClient.get_instance() +class BotError(StrategyError): + """Base class for bot-related errors""" + +class BotNotFoundError(BotError): + """Raised when a bot cannot be found""" + +class BotPermissionError(BotError): + """Raised when there's a permission error with bot operations""" + +class BotConfigError(BotError): + """Raised when there's an error in bot configuration""" class CreateBotRequest(BaseModel): strategy_name: str strategy_parameters: dict market: str - @router.post("/bots", response_model=InstanceResponse) async def create_bot(request: CreateBotRequest, wallet_auth: JWTWalletAuthDep): - bot_account = f"robotter_{wallet_auth.address}_{request.market}_{request.strategy_name}" - accounts_service.add_account(bot_account) - wallet_address = await accounts_service.generate_bot_wallet(bot_account) - - # Save strategy configuration and market - bot_config = BotConfig( - strategy_name=request.strategy_name, - parameters=request.strategy_parameters, - market=request.market, - wallet_address=wallet_auth.address, - ) - accounts_service.save_bot_config(bot_account, bot_config) - # Create Hummingbot instance - instance_config = HummingbotInstanceConfig( - instance_name=bot_account, credentials_profile=bot_account, image="mlguys/hummingbot:mango", market=request.market - ) - result = docker_manager.create_hummingbot_instance(instance_config) - - if not result["success"]: - raise HTTPException(status_code=500, detail=result["message"]) - - return InstanceResponse(instance_id=bot_account, wallet_address=wallet_address, market=request.market) - + try: + # Validate strategy exists and parameters + try: + strategy = StrategyRegistry.get_strategy(request.strategy_name) + except StrategyNotFoundError as e: + raise BotConfigError(f"Invalid strategy: {str(e)}") + + try: + validate_strategy_parameters(strategy, request.strategy_parameters) + except StrategyValidationError as e: + raise BotConfigError(f"Invalid strategy parameters: {str(e)}") + + # Create bot account + bot_account = f"robotter_{wallet_auth.address}_{request.market}_{request.strategy_name}" + try: + accounts_service.add_account(bot_account) + wallet_address = await accounts_service.generate_bot_wallet(bot_account) + except Exception as e: + raise BotError(f"Error creating bot account: {str(e)}") + + # Save strategy configuration and market + try: + bot_config = BotConfig( + strategy_name=request.strategy_name, + parameters=request.strategy_parameters, + market=request.market, + wallet_address=wallet_auth.address, + ) + accounts_service.save_bot_config(bot_account, bot_config) + except Exception as e: + raise BotConfigError(f"Error saving bot configuration: {str(e)}") + + # Create Hummingbot instance + try: + instance_config = HummingbotInstanceConfig( + instance_name=bot_account, + credentials_profile=bot_account, + image="mlguys/hummingbot:mango", + market=request.market + ) + result = docker_manager.create_hummingbot_instance(instance_config) + + if not result["success"]: + raise BotError(result["message"]) + except Exception as e: + raise BotError(f"Error creating Hummingbot instance: {str(e)}") + + return InstanceResponse( + instance_id=bot_account, + wallet_address=wallet_address, + market=request.market + ) + + except BotConfigError as e: + raise HTTPException(status_code=400, detail=str(e)) + except BotPermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + except BotError as e: + raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Unexpected error creating bot: {str(e)}" + ) @router.get("/bots/{bot_id}/wallet") async def get_bot_wallet(bot_id: str, wallet_auth: JWTWalletAuthDep): try: # Check if the bot belongs to the authenticated user if not bot_id.endswith(wallet_auth.address): - raise HTTPException(status_code=403, detail="You don't have permission to access this bot") + raise BotPermissionError("You don't have permission to access this bot") - wallet_address = accounts_service.get_bot_wallet_address(bot_id) - return {"wallet_address": wallet_address} - except Exception as e: - raise HTTPException(status_code=404, detail=str(e)) + try: + wallet_address = accounts_service.get_bot_wallet_address(bot_id) + return {"wallet_address": wallet_address} + except Exception as e: + raise BotNotFoundError(f"Bot wallet not found: {str(e)}") + except BotPermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + except BotNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Unexpected error getting bot wallet: {str(e)}" + ) @router.post("/bots/{bot_id}/start") async def start_bot(bot_id: str, start_request: StartStrategyRequest, wallet_auth: JWTWalletAuthDep): - # Check if the bot belongs to the authenticated user - bot_config = accounts_service.get_bot_config(bot_id) - if not bot_config or bot_config.wallet_address != wallet_auth.address: - raise HTTPException(status_code=403, detail="You don't have permission to start this bot") - - # Check if Mango account exists and is associated with the bot's wallet - bot_wallet = accounts_service.get_bot_wallet_address(bot_id) - - # We should pass the wallet address - mango_account_info = await gateway_client.get_mango_account( - "solana", "mainnet", "mango_perpetual_solana_mainnet-beta", bot_wallet - ) - - if not mango_account_info or mango_account_info.get("owner") != bot_wallet: - raise HTTPException(status_code=400, detail="Invalid Mango account or not associated with the bot's wallet") - - # Start the bot - strategy_config = accounts_service.get_strategy_config(bot_id) - start_config = {**strategy_config, **start_request.parameters} - - response = docker_manager.start_bot(bot_id, start_config) - if not response["success"]: - raise HTTPException(status_code=500, detail="Failed to start the bot") - - return {"status": "success", "message": "Bot started successfully"} - - -@router.post("/bots/{bot_id}/stop") -async def stop_bot(bot_id: str, wallet_auth: JWTWalletAuthDep): - # Check if the bot belongs to the authenticated user - if not bot_id.endswith(wallet_auth.address): - raise HTTPException(status_code=403, detail="You don't have permission to stop this bot") - - # Stop the bot and cancel all orders - response = docker_manager.stop_bot(bot_id) - if not response["success"]: - raise HTTPException(status_code=500, detail="Failed to stop the bot") - - # Cancel all orders through the gateway - bot_wallet = accounts_service.get_bot_wallet_address(bot_id) - cancel_orders_response = await gateway_client.clob_perp_get_orders("solana", "mainnet", "mango_perpetual_solana_mainnet-beta") - - if not cancel_orders_response["success"]: - raise HTTPException(status_code=500, detail="Failed to cancel all orders") - - return {"status": "success", "message": "Bot stopped and all orders cancelled successfully"} \ No newline at end of file + try: + # Check if the bot belongs to the authenticated user + bot_config = accounts_service.get_bot_config(bot_id) + if not bot_config or bot_config.wallet_address != wallet_auth.address: + raise BotPermissionError("You don't have permission to start this bot") + + # Check if Mango account exists and is associated with the bot's wallet + try: + bot_wallet = accounts_service.get_bot_wallet_address(bot_id) + mango_account_info = await gateway_client.get_mango_account( + "solana", "mainnet", "mango_perpetual_solana_mainnet-beta", bot_wallet + ) + + if not mango_account_info or mango_account_info.get("owner") != bot_wallet: + raise BotConfigError("Invalid Mango account or not associated with the bot's wallet") + except Exception as e: + raise BotConfigError(f"Error validating Mango account: {str(e)}") + + # Start the bot + try: + strategy_config = accounts_service.get_strategy_config(bot_id) + start_config = {**strategy_config, **start_request.parameters} + + response = docker_manager.start_bot(bot_id, start_config) + if not response["success"]: + raise BotError("Failed to start the bot") + + return {"status": "success", "message": "Bot started successfully"} + except Exception as e: + raise BotError(f"Error starting bot: {str(e)}") + + except BotPermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + except BotConfigError as e: + raise HTTPException(status_code=400, detail=str(e)) + except BotError as e: + raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Unexpected error starting bot: {str(e)}" + ) + +def validate_strategy_parameters(strategy, parameters: dict) -> None: + """Validate strategy parameters against their constraints""" + for param_name, value in parameters.items(): + if param_name not in strategy.parameters: + raise StrategyValidationError(f"Unknown parameter: {param_name}") + + param = strategy.parameters[param_name] + constraints = param.constraints + + if constraints: + if constraints.min_value is not None and value < constraints.min_value: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is below minimum {constraints.min_value}" + ) + + if constraints.max_value is not None and value > constraints.max_value: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is above maximum {constraints.max_value}" + ) + + if constraints.valid_values is not None and value not in constraints.valid_values: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is not one of the valid values: {constraints.valid_values}" + ) \ No newline at end of file diff --git a/routers/strategies.py b/routers/strategies.py index 9366d77..dd9efe8 100644 --- a/routers/strategies.py +++ b/routers/strategies.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Any from fastapi import APIRouter, HTTPException from fastapi import FastAPI from contextlib import asynccontextmanager @@ -8,7 +8,10 @@ ParameterSuggestionRequest, ParameterSuggestionResponse, StrategyConfig, - discover_strategies + StrategyRegistry, + StrategyNotFoundError, + StrategyValidationError, + StrategyError ) # Create a libert_ai_service instance @@ -20,7 +23,7 @@ async def lifespan(app: FastAPI): try: # Load strategies using auto-discovery print("Initializing LibertAI contexts...") - strategies = discover_strategies() + strategies = StrategyRegistry.get_all_strategies() await libert_ai_service.initialize_contexts(strategies) print(f"Successfully initialized contexts for {len(strategies)} strategies") except Exception as e: @@ -39,10 +42,14 @@ async def lifespan(app: FastAPI): async def get_strategies() -> Dict[str, StrategyConfig]: """Get all available strategies and their configurations.""" try: - # Use auto-discovery to get strategies - return discover_strategies() + return StrategyRegistry.get_all_strategies() + except StrategyError as e: + raise HTTPException(status_code=400, detail=str(e)) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail=f"Internal server error while fetching strategies: {str(e)}" + ) @router.post("/strategies/suggest-parameters") async def suggest_parameters(request: ParameterSuggestionRequest) -> ParameterSuggestionResponse: @@ -54,13 +61,17 @@ async def suggest_parameters(request: ParameterSuggestionRequest) -> ParameterSu Otherwise, will suggest values for all missing required parameters. """ try: - # Get strategy configuration using auto-discovery - strategies = discover_strategies() - - if request.strategy_id not in strategies: - raise HTTPException(status_code=404, detail=f"Strategy '{request.strategy_id}' not found") + # Get strategy using the registry + try: + strategy = StrategyRegistry.get_strategy(request.strategy_id) + except StrategyNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) - strategy = strategies[request.strategy_id] + # Validate parameters against constraints + try: + validate_parameters(strategy, request.parameters) + except StrategyValidationError as e: + raise HTTPException(status_code=400, detail=str(e)) # Ensure context is initialized for this strategy if request.strategy_id not in libert_ai_service.strategy_slot_map: @@ -90,9 +101,40 @@ async def suggest_parameters(request: ParameterSuggestionRequest) -> ParameterSu ) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail=f"Error getting parameter suggestions: {str(e)}" + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail=f"Unexpected error processing parameter suggestions: {str(e)}" + ) + +def validate_parameters(strategy: StrategyConfig, parameters: Dict[str, Any]) -> None: + """Validate parameters against their constraints""" + for param_name, value in parameters.items(): + if param_name not in strategy.parameters: + raise StrategyValidationError(f"Unknown parameter: {param_name}") + + param = strategy.parameters[param_name] + constraints = param.constraints + + if constraints: + if constraints.min_value is not None and value < constraints.min_value: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is below minimum {constraints.min_value}" + ) + + if constraints.max_value is not None and value > constraints.max_value: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is above maximum {constraints.max_value}" + ) + + if constraints.valid_values is not None and value not in constraints.valid_values: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is not one of the valid values: {constraints.valid_values}" + )