diff --git a/conf/conf_client.yml b/conf/conf_client.yml index dd8e81e..856639f 100644 --- a/conf/conf_client.yml +++ b/conf/conf_client.yml @@ -107,7 +107,7 @@ gateway: gateway_api_host: localhost gateway_api_port: '15888' -certs_path: ./certs +certs_path: certs # Whether to enable aggregated order and trade data collection anonymized_metrics_mode: diff --git a/pyproject.toml b/pyproject.toml index 4ad399a..ea36189 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,23 +1,26 @@ [build-system] -requires = ["setuptools>=42", "wheel"] -build-backend = "setuptools.build_meta" +requires = ["hatchling"] +build-backend = "hatchling.build" -[tool.black] -line-length = 130 -target-version = ["py38"] - -[tool.isort] -line_length = 130 -profile = "black" -multi_line_output = 3 -include_trailing_comma = true -use_parentheses = true -ensure_newline_before_comments = true -combine_as_imports = true +[project] +name = "robotter-backend-api" +version = "0.1.0" +authors = [ + { name="Mike Henry" }, +] +description = "Robotter Backend API" +readme = "README.md" +requires-python = ">=3.10" -[tool.pre-commit] -repos = [ - { repo = "https://github.com/pre-commit/pre-commit-hooks", rev = "v3.4.0", hooks = [{ id = "check-yaml" }, { id = "end-of-file-fixer" }] }, - { repo = "https://github.com/psf/black", rev = "21.6b0", hooks = [{ id = "black" }] }, - { repo = "https://github.com/pre-commit/mirrors-isort", rev = "v5.9.3", hooks = [{ id = "isort" }] } +[tool.pytest.ini_options] +pythonpath = [ + "." +] +asyncio_mode = "strict" +testpaths = [ + "tests", +] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::UserWarning", ] diff --git a/routers/__init__.py b/routers/__init__.py index e69de29..3ca04e5 100644 --- a/routers/__init__.py +++ b/routers/__init__.py @@ -0,0 +1,3 @@ +""" +Router package initialization. +""" diff --git a/routers/backtest.py b/routers/backtest.py index ec78fd8..1b080b6 100644 --- a/routers/backtest.py +++ b/routers/backtest.py @@ -8,6 +8,7 @@ from config import CONTROLLERS_MODULE, CONTROLLERS_PATH from routers.backtest_models import BacktestResponse, BacktestResults, BacktestingConfig, ExecutorInfo, ProcessedData +from routers.strategies_models import StrategyError router = APIRouter(tags=["Market Backtesting"]) candles_factory = CandlesFactory() @@ -19,37 +20,84 @@ "market_making": market_making_backtesting } +class BacktestError(StrategyError): + """Base class for backtesting-related errors""" + +class BacktestConfigError(BacktestError): + """Raised when there's an error in the backtesting configuration""" + +class BacktestEngineError(BacktestError): + """Raised when there's an error during backtesting execution""" + @router.post("/backtest", response_model=BacktestResponse) async def run_backtesting(backtesting_config: BacktestingConfig) -> BacktestResponse: try: - if isinstance(backtesting_config.config, str): - controller_config = BacktestingEngineBase.get_controller_config_instance_from_yml( - config_path=backtesting_config.config, - controllers_conf_dir_path=CONTROLLERS_PATH, - controllers_module=CONTROLLERS_MODULE - ) - else: - controller_config = BacktestingEngineBase.get_controller_config_instance_from_dict( - config_data=backtesting_config.config, - controllers_module=CONTROLLERS_MODULE - ) + # Load and validate controller config + try: + if isinstance(backtesting_config.config, str): + controller_config = BacktestingEngineBase.get_controller_config_instance_from_yml( + config_path=backtesting_config.config, + controllers_conf_dir_path=CONTROLLERS_PATH, + controllers_module=CONTROLLERS_MODULE + ) + else: + controller_config = BacktestingEngineBase.get_controller_config_instance_from_dict( + config_data=backtesting_config.config, + controllers_module=CONTROLLERS_MODULE + ) + except Exception as e: + raise BacktestConfigError(f"Invalid controller configuration: {str(e)}") + + # Get and validate backtesting engine backtesting_engine = BACKTESTING_ENGINES.get(controller_config.controller_type) if not backtesting_engine: - raise ValueError(f"Backtesting engine for controller type {controller_config.controller_type} not found.") - backtesting_results = await backtesting_engine.run_backtesting( - controller_config=controller_config, trade_cost=backtesting_config.trade_cost, - start=int(backtesting_config.start_time), end=int(backtesting_config.end_time), - backtesting_resolution=backtesting_config.backtesting_resolution) - - processed_data = backtesting_results["processed_data"]["features"].fillna(0).to_dict() - executors_info = [ExecutorInfo(**e.to_dict()) for e in backtesting_results["executors"]] - results = backtesting_results["results"] - results["sharpe_ratio"] = results["sharpe_ratio"] if results["sharpe_ratio"] is not None else 0 - - return BacktestResponse( - executors=executors_info, - processed_data=ProcessedData(features=processed_data), - results=BacktestResults(**results) - ) + raise BacktestConfigError( + f"Backtesting engine for controller type {controller_config.controller_type} not found. " + f"Available types: {list(BACKTESTING_ENGINES.keys())}" + ) + + # Validate time range + if backtesting_config.end_time <= backtesting_config.start_time: + raise BacktestConfigError( + f"Invalid time range: end_time ({backtesting_config.end_time}) must be greater than " + f"start_time ({backtesting_config.start_time})" + ) + + try: + # Run backtesting + backtesting_results = await backtesting_engine.run_backtesting( + controller_config=controller_config, + trade_cost=backtesting_config.trade_cost, + start=int(backtesting_config.start_time), + end=int(backtesting_config.end_time), + backtesting_resolution=backtesting_config.backtesting_resolution + ) + except Exception as e: + raise BacktestEngineError(f"Error during backtesting execution: {str(e)}") + + try: + # Process results + processed_data = backtesting_results["processed_data"]["features"].fillna(0).to_dict() + executors_info = [ExecutorInfo(**e.to_dict()) for e in backtesting_results["executors"]] + results = backtesting_results["results"] + results["sharpe_ratio"] = results["sharpe_ratio"] if results["sharpe_ratio"] is not None else 0 + + return BacktestResponse( + executors=executors_info, + processed_data=ProcessedData(features=processed_data), + results=BacktestResults(**results) + ) + except Exception as e: + raise BacktestError(f"Error processing backtesting results: {str(e)}") + + except BacktestConfigError as e: + raise HTTPException(status_code=400, detail=str(e)) + except BacktestEngineError as e: + raise HTTPException(status_code=500, detail=str(e)) + except BacktestError as e: + raise HTTPException(status_code=500, detail=str(e)) except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) \ No newline at end of file + raise HTTPException( + status_code=500, + detail=f"Unexpected error during backtesting: {str(e)}" + ) \ No newline at end of file diff --git a/routers/bots.py b/routers/bots.py index ed653ae..f3fbf63 100644 --- a/routers/bots.py +++ b/routers/bots.py @@ -5,103 +5,191 @@ from fastapi_walletauth import JWTWalletAuthDep from utils.models import HummingbotInstanceConfig, StartStrategyRequest, InstanceResponse from hummingbot.core.gateway.gateway_http_client import GatewayHttpClient +from routers.strategies_models import ( + StrategyError, + StrategyRegistry, + StrategyNotFoundError, + StrategyValidationError +) router = APIRouter(tags=["Bot Management"]) accounts_service = AccountsService() docker_manager = DockerManager() gateway_client = GatewayHttpClient.get_instance() +class BotError(StrategyError): + """Base class for bot-related errors""" + +class BotNotFoundError(BotError): + """Raised when a bot cannot be found""" + +class BotPermissionError(BotError): + """Raised when there's a permission error with bot operations""" + +class BotConfigError(BotError): + """Raised when there's an error in bot configuration""" class CreateBotRequest(BaseModel): strategy_name: str strategy_parameters: dict market: str - @router.post("/bots", response_model=InstanceResponse) async def create_bot(request: CreateBotRequest, wallet_auth: JWTWalletAuthDep): - bot_account = f"robotter_{wallet_auth.address}_{request.market}_{request.strategy_name}" - accounts_service.add_account(bot_account) - wallet_address = await accounts_service.generate_bot_wallet(bot_account) - - # Save strategy configuration and market - bot_config = BotConfig( - strategy_name=request.strategy_name, - parameters=request.strategy_parameters, - market=request.market, - wallet_address=wallet_auth.address, - ) - accounts_service.save_bot_config(bot_account, bot_config) - # Create Hummingbot instance - instance_config = HummingbotInstanceConfig( - instance_name=bot_account, credentials_profile=bot_account, image="mlguys/hummingbot:mango", market=request.market - ) - result = docker_manager.create_hummingbot_instance(instance_config) - - if not result["success"]: - raise HTTPException(status_code=500, detail=result["message"]) - - return InstanceResponse(instance_id=bot_account, wallet_address=wallet_address, market=request.market) - + try: + # Validate strategy exists and parameters + try: + strategy = StrategyRegistry.get_strategy(request.strategy_name) + except StrategyNotFoundError as e: + raise BotConfigError(f"Invalid strategy: {str(e)}") + + try: + validate_strategy_parameters(strategy, request.strategy_parameters) + except StrategyValidationError as e: + raise BotConfigError(f"Invalid strategy parameters: {str(e)}") + + # Create bot account + bot_account = f"robotter_{wallet_auth.address}_{request.market}_{request.strategy_name}" + try: + accounts_service.add_account(bot_account) + wallet_address = await accounts_service.generate_bot_wallet(bot_account) + except Exception as e: + raise BotError(f"Error creating bot account: {str(e)}") + + # Save strategy configuration and market + try: + bot_config = BotConfig( + strategy_name=request.strategy_name, + parameters=request.strategy_parameters, + market=request.market, + wallet_address=wallet_auth.address, + ) + accounts_service.save_bot_config(bot_account, bot_config) + except Exception as e: + raise BotConfigError(f"Error saving bot configuration: {str(e)}") + + # Create Hummingbot instance + try: + instance_config = HummingbotInstanceConfig( + instance_name=bot_account, + credentials_profile=bot_account, + image="mlguys/hummingbot:mango", + market=request.market + ) + result = docker_manager.create_hummingbot_instance(instance_config) + + if not result["success"]: + raise BotError(result["message"]) + except Exception as e: + raise BotError(f"Error creating Hummingbot instance: {str(e)}") + + return InstanceResponse( + instance_id=bot_account, + wallet_address=wallet_address, + market=request.market + ) + + except BotConfigError as e: + raise HTTPException(status_code=400, detail=str(e)) + except BotPermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + except BotError as e: + raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Unexpected error creating bot: {str(e)}" + ) @router.get("/bots/{bot_id}/wallet") async def get_bot_wallet(bot_id: str, wallet_auth: JWTWalletAuthDep): try: # Check if the bot belongs to the authenticated user if not bot_id.endswith(wallet_auth.address): - raise HTTPException(status_code=403, detail="You don't have permission to access this bot") + raise BotPermissionError("You don't have permission to access this bot") - wallet_address = accounts_service.get_bot_wallet_address(bot_id) - return {"wallet_address": wallet_address} - except Exception as e: - raise HTTPException(status_code=404, detail=str(e)) + try: + wallet_address = accounts_service.get_bot_wallet_address(bot_id) + return {"wallet_address": wallet_address} + except Exception as e: + raise BotNotFoundError(f"Bot wallet not found: {str(e)}") + except BotPermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + except BotNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Unexpected error getting bot wallet: {str(e)}" + ) @router.post("/bots/{bot_id}/start") async def start_bot(bot_id: str, start_request: StartStrategyRequest, wallet_auth: JWTWalletAuthDep): - # Check if the bot belongs to the authenticated user - bot_config = accounts_service.get_bot_config(bot_id) - if not bot_config or bot_config.wallet_address != wallet_auth.address: - raise HTTPException(status_code=403, detail="You don't have permission to start this bot") - - # Check if Mango account exists and is associated with the bot's wallet - bot_wallet = accounts_service.get_bot_wallet_address(bot_id) - - # We should pass the wallet address - mango_account_info = await gateway_client.get_mango_account( - "solana", "mainnet", "mango_perpetual_solana_mainnet-beta", bot_wallet - ) - - if not mango_account_info or mango_account_info.get("owner") != bot_wallet: - raise HTTPException(status_code=400, detail="Invalid Mango account or not associated with the bot's wallet") - - # Start the bot - strategy_config = accounts_service.get_strategy_config(bot_id) - start_config = {**strategy_config, **start_request.parameters} - - response = docker_manager.start_bot(bot_id, start_config) - if not response["success"]: - raise HTTPException(status_code=500, detail="Failed to start the bot") - - return {"status": "success", "message": "Bot started successfully"} - - -@router.post("/bots/{bot_id}/stop") -async def stop_bot(bot_id: str, wallet_auth: JWTWalletAuthDep): - # Check if the bot belongs to the authenticated user - if not bot_id.endswith(wallet_auth.address): - raise HTTPException(status_code=403, detail="You don't have permission to stop this bot") - - # Stop the bot and cancel all orders - response = docker_manager.stop_bot(bot_id) - if not response["success"]: - raise HTTPException(status_code=500, detail="Failed to stop the bot") - - # Cancel all orders through the gateway - bot_wallet = accounts_service.get_bot_wallet_address(bot_id) - cancel_orders_response = await gateway_client.clob_perp_get_orders("solana", "mainnet", "mango_perpetual_solana_mainnet-beta") - - if not cancel_orders_response["success"]: - raise HTTPException(status_code=500, detail="Failed to cancel all orders") - - return {"status": "success", "message": "Bot stopped and all orders cancelled successfully"} \ No newline at end of file + try: + # Check if the bot belongs to the authenticated user + bot_config = accounts_service.get_bot_config(bot_id) + if not bot_config or bot_config.wallet_address != wallet_auth.address: + raise BotPermissionError("You don't have permission to start this bot") + + # Check if Mango account exists and is associated with the bot's wallet + try: + bot_wallet = accounts_service.get_bot_wallet_address(bot_id) + mango_account_info = await gateway_client.get_mango_account( + "solana", "mainnet", "mango_perpetual_solana_mainnet-beta", bot_wallet + ) + + if not mango_account_info or mango_account_info.get("owner") != bot_wallet: + raise BotConfigError("Invalid Mango account or not associated with the bot's wallet") + except Exception as e: + raise BotConfigError(f"Error validating Mango account: {str(e)}") + + # Start the bot + try: + strategy_config = accounts_service.get_strategy_config(bot_id) + start_config = {**strategy_config, **start_request.parameters} + + response = docker_manager.start_bot(bot_id, start_config) + if not response["success"]: + raise BotError("Failed to start the bot") + + return {"status": "success", "message": "Bot started successfully"} + except Exception as e: + raise BotError(f"Error starting bot: {str(e)}") + + except BotPermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + except BotConfigError as e: + raise HTTPException(status_code=400, detail=str(e)) + except BotError as e: + raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Unexpected error starting bot: {str(e)}" + ) + +def validate_strategy_parameters(strategy, parameters: dict) -> None: + """Validate strategy parameters against their constraints""" + for param_name, value in parameters.items(): + if param_name not in strategy.parameters: + raise StrategyValidationError(f"Unknown parameter: {param_name}") + + param = strategy.parameters[param_name] + constraints = param.constraints + + if constraints: + if constraints.min_value is not None and value < constraints.min_value: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is below minimum {constraints.min_value}" + ) + + if constraints.max_value is not None and value > constraints.max_value: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is above maximum {constraints.max_value}" + ) + + if constraints.valid_values is not None and value not in constraints.valid_values: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is not one of the valid values: {constraints.valid_values}" + ) \ No newline at end of file diff --git a/routers/strategies.py b/routers/strategies.py index ce0fff3..dd9efe8 100644 --- a/routers/strategies.py +++ b/routers/strategies.py @@ -1,11 +1,140 @@ -from typing import Dict +from typing import Dict, Any +from fastapi import APIRouter, HTTPException +from fastapi import FastAPI +from contextlib import asynccontextmanager -from fastapi import APIRouter -from .strategies_models import StrategyParameter, get_all_strategy_maps +from services.libert_ai_service import LibertAIService +from routers.strategies_models import ( + ParameterSuggestionRequest, + ParameterSuggestionResponse, + StrategyConfig, + StrategyRegistry, + StrategyNotFoundError, + StrategyValidationError, + StrategyError +) -router = APIRouter(tags=["Strategies"]) +# Create a libert_ai_service instance +libert_ai_service = LibertAIService() +@asynccontextmanager +async def lifespan(app: FastAPI): + # Initialize contexts on startup + try: + # Load strategies using auto-discovery + print("Initializing LibertAI contexts...") + strategies = StrategyRegistry.get_all_strategies() + await libert_ai_service.initialize_contexts(strategies) + print(f"Successfully initialized contexts for {len(strategies)} strategies") + except Exception as e: + print(f"Error initializing LibertAI contexts: {str(e)}") + # Re-raise the exception to prevent app startup if context initialization fails + raise + yield + # Cleanup on shutdown if needed + print("Cleaning up LibertAI contexts...") -@router.get("/strategies", response_model=Dict[str, Dict[str, StrategyParameter]]) -async def get_strategies(): - return get_all_strategy_maps() +# Create the FastAPI app with the lifespan handler +app = FastAPI(lifespan=lifespan) +router = APIRouter() + +@router.get("/strategies") +async def get_strategies() -> Dict[str, StrategyConfig]: + """Get all available strategies and their configurations.""" + try: + return StrategyRegistry.get_all_strategies() + except StrategyError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Internal server error while fetching strategies: {str(e)}" + ) + +@router.post("/strategies/suggest-parameters") +async def suggest_parameters(request: ParameterSuggestionRequest) -> ParameterSuggestionResponse: + """ + Suggest parameter values for a strategy based on the provided parameters. + Uses LibertAI to analyze and suggest optimal values for missing or requested parameters. + + If requested_parameters is provided, will only suggest values for those specific parameters. + Otherwise, will suggest values for all missing required parameters. + """ + try: + # Get strategy using the registry + try: + strategy = StrategyRegistry.get_strategy(request.strategy_id) + except StrategyNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + # Validate parameters against constraints + try: + validate_parameters(strategy, request.parameters) + except StrategyValidationError as e: + raise HTTPException(status_code=400, detail=str(e)) + + # Ensure context is initialized for this strategy + if request.strategy_id not in libert_ai_service.strategy_slot_map: + print(f"Re-initializing context for strategy {request.strategy_id}") + await libert_ai_service.initialize_contexts({request.strategy_id: strategy}) + + try: + # Get suggestions from LibertAI + suggestions = await libert_ai_service.get_parameter_suggestions( + strategy_id=request.strategy_id, + strategy_config=strategy.parameters, + provided_params=request.parameters, + requested_params=request.requested_parameters + ) + + # Extract summary from the last suggestion if it exists + summary = "No suggestions available." + if suggestions: + # Remove the summary suggestion if it exists + if suggestions[-1].parameter_name.lower() == "summary": + summary = suggestions[-1].reasoning + suggestions = suggestions[:-1] + + return ParameterSuggestionResponse( + suggestions=suggestions, + summary=summary + ) + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error getting parameter suggestions: {str(e)}" + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Unexpected error processing parameter suggestions: {str(e)}" + ) + +def validate_parameters(strategy: StrategyConfig, parameters: Dict[str, Any]) -> None: + """Validate parameters against their constraints""" + for param_name, value in parameters.items(): + if param_name not in strategy.parameters: + raise StrategyValidationError(f"Unknown parameter: {param_name}") + + param = strategy.parameters[param_name] + constraints = param.constraints + + if constraints: + if constraints.min_value is not None and value < constraints.min_value: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is below minimum {constraints.min_value}" + ) + + if constraints.max_value is not None and value > constraints.max_value: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is above maximum {constraints.max_value}" + ) + + if constraints.valid_values is not None and value not in constraints.valid_values: + raise StrategyValidationError( + f"Parameter {param_name} value {value} is not one of the valid values: {constraints.valid_values}" + ) diff --git a/routers/strategies_models.py b/routers/strategies_models.py index 033b7b5..070ad7a 100644 --- a/routers/strategies_models.py +++ b/routers/strategies_models.py @@ -1,6 +1,5 @@ from enum import Enum from typing import Any, Dict, Optional, Union, List - from pydantic import BaseModel, Field from decimal import Decimal from hummingbot.strategy_v2.controllers import MarketMakingControllerConfigBase, ControllerConfigBase, DirectionalTradingControllerConfigBase @@ -12,56 +11,124 @@ from pydantic.fields import ModelField from pydantic.main import ModelMetaclass -from bots.controllers.directional_trading.bollinger_v1 import BollingerV1ControllerConfig - logger = ( logging.getLogger(__name__) if __name__ != "__main__" else logging.getLogger("uvicorn") ) -class StrategyParameter(BaseModel): - name: str - group: str - is_advanced: bool = False - pretty_name: str - description: str - type: str - prompt: str - default: Optional[Any] - required: bool +class StrategyType(str, Enum): + DIRECTIONAL_TRADING = "directional_trading" + MARKET_MAKING = "market_making" + GENERIC = "generic" + +class DisplayType(str, Enum): + INPUT = "input" + SLIDER = "slider" + DROPDOWN = "dropdown" + TOGGLE = "toggle" + DATE = "date" + +class ParameterType(str, Enum): + PERCENTAGE = "percentage" + PRICE = "price" + TIMESPAN = "timespan" + CONNECTOR = "connector" + TRADING_PAIR = "trading_pair" + INTEGER = "integer" + DECIMAL = "decimal" + STRING = "string" + BOOLEAN = "boolean" + +class ParameterGroup(str, Enum): + GENERAL = "General Settings" + RISK = "Risk Management" + BUY = "Buy Order Settings" + SELL = "Sell Order Settings" + DCA = "DCA Settings" + INDICATORS = "Indicator Settings" + PROFITABILITY = "Profitability Settings" + EXECUTION = "Execution Settings" + ARBITRAGE = "Arbitrage Settings" + OTHER = "Other" + +class StrategyError(Exception): + """Base class for strategy-related errors""" + +class StrategyNotFoundError(StrategyError): + """Raised when a strategy cannot be found""" + +class StrategyValidationError(StrategyError): + """Raised when strategy parameters are invalid""" + +class ParameterConstraints(BaseModel): min_value: Optional[Union[int, float, Decimal]] = None max_value: Optional[Union[int, float, Decimal]] = None valid_values: Optional[List[Any]] = None - is_percentage: bool = False - is_price: bool = False - is_timespan: bool = False - is_connector: bool = False - is_trading_pair: bool = False - is_integer: bool = False - display_type: str = Field(default="input", description="Can be 'input', 'slider', 'dropdown', 'toggle', or 'date'") - -def is_advanced_parameter(name: str) -> bool: - advanced_keywords = [ - "activation_bounds", "triple_barrier", "leverage", "dca", "macd", "natr", - "multiplier", "imbalance", "executor", "perp", "arbitrage" - ] - - simple_keywords = [ - "controller_name", "candles", "interval", "stop_loss", "take_profit", - "buy", "sell", "position_size", "time_limit", "spot" - ] +class StrategyParameter(BaseModel): + # Core attributes + name: str + type: str + required: bool + default: Optional[Any] - name_lower = name.lower() + # Display attributes + display_name: str + description: str + group: ParameterGroup + is_advanced: bool - if any(keyword in name_lower for keyword in advanced_keywords): - return True + # Validation attributes + constraints: Optional[ParameterConstraints] = None - if any(keyword in name_lower for keyword in simple_keywords): - return False + # UI attributes + display_type: DisplayType = DisplayType.INPUT - return True + # Type flags (for backward compatibility and specific handling) + parameter_type: Optional[ParameterType] = None + +class Strategy(BaseModel): + id: str + name: str + description: str + type: StrategyType + module_path: str + config_class: str + parameters: Dict[str, StrategyParameter] + +class StrategyMapping(BaseModel): + """Maps a strategy ID to its implementation details""" + id: str # e.g., "supertrend_v1" + config_class: str # e.g., "supertrendconfig" + module_path: str # e.g., "bots.controllers.directional_trading.supertrend_v1" + strategy_type: StrategyType + display_name: str # e.g., "Supertrend V1" + description: str = "" + +class StrategyConfig(BaseModel): + """Complete strategy configuration including metadata and parameters""" + mapping: StrategyMapping + parameters: Dict[str, StrategyParameter] + +class ParameterSuggestionRequest(BaseModel): + strategy_id: str + parameters: Dict[str, Any] + requested_parameters: Optional[List[str]] = Field( + default=None, + description="Optional list of specific parameters to get suggestions for. If not provided, will suggest values for all missing required parameters." + ) + +class ParameterSuggestion(BaseModel): + parameter_name: str + suggested_value: str + reasoning: str + differs_from_default: bool = False + differs_from_provided: bool = False + +class ParameterSuggestionResponse(BaseModel): + suggestions: List[ParameterSuggestion] + summary: str def get_strategy_display_info() -> Dict[str, Dict[str, str]]: """ @@ -85,7 +152,7 @@ def get_strategy_display_info() -> Dict[str, Dict[str, str]]: "pretty_name": "SuperTrend Strategy", "description": "Follows market trends to find good times to buy and sell." }, - + # Market Making Strategies "dman_maker_v2": { "pretty_name": "Smart Market Maker", @@ -99,7 +166,7 @@ def get_strategy_display_info() -> Dict[str, Dict[str, str]]: "pretty_name": "Simple Market Maker", "description": "Places basic buy and sell orders with fixed spreads." }, - + # Generic Strategies "spot_perp_arbitrage": { "pretty_name": "Spot-Futures Arbitrage", @@ -111,105 +178,163 @@ def get_strategy_display_info() -> Dict[str, Dict[str, str]]: } } -def convert_to_strategy_parameter(name: str, field: ModelField) -> StrategyParameter: - param = StrategyParameter( - name=name, - type=str(field.type_.__name__), - prompt=field.description if hasattr(field, 'description') else "", - default=field.default, - required=field.required or field.default is not None, - is_advanced=is_advanced_parameter(name), - group=determine_parameter_group(name), - pretty_name=name.replace('_', ' ').title(), - description="", - ) +class StrategyRegistry: + """Central registry for all trading strategies""" - # Get strategy display info - strategy_info = get_strategy_display_info() + _cache: Dict[str, Strategy] = {} - # Try to find matching strategy info - for strategy_name, info in strategy_info.items(): - if strategy_name in name.lower(): - param.pretty_name = info["pretty_name"] - param.description = info["description"] - break + @classmethod + def _ensure_cache_loaded(cls): + if not cls._cache: + cls._cache = discover_strategies() - # structure of field - client_data = field.field_info.extra.get('client_data') - if client_data is not None and param.prompt == "": - desc = client_data.prompt(None) if callable(client_data.prompt) else client_data.prompt - if desc is not None: - param.prompt = desc - if not param.required: - param.required = client_data.prompt_on_new if hasattr(client_data, 'prompt_on_new') else param.required - param.display_type = "input" + @classmethod + def get_all_strategies(cls) -> Dict[str, Strategy]: + """Get all available strategies with their configurations""" + cls._ensure_cache_loaded() + return cls._cache - # Check for gt (greater than) and lt (less than) in field definition - if hasattr(field.field_info, 'gt'): - param.min_value = field.field_info.gt - if hasattr(field.field_info, 'lt'): - param.max_value = field.field_info.lt + @classmethod + def get_strategy(cls, strategy_id: str) -> Optional[Strategy]: + """Get a specific strategy by ID""" + cls._ensure_cache_loaded() + strategy = cls._cache.get(strategy_id) + if not strategy: + raise StrategyNotFoundError(f"Strategy '{strategy_id}' not found") + return strategy - # Set display_type to "slider" only if both min and max values are present - if param.min_value is not None and param.max_value is not None: - param.display_type = "slider" - elif param.type == "bool": - param.display_type = "toggle" + @classmethod + def get_strategies_by_type(cls, strategy_type: StrategyType) -> List[Strategy]: + """Get all strategies of a specific type""" + cls._ensure_cache_loaded() + return [s for s in cls._cache.values() if s.type == strategy_type] + +def convert_to_strategy_parameter(name: str, field: ModelField) -> StrategyParameter: + """Convert a model field to a strategy parameter""" + constraints = ParameterConstraints() + # Handle constraints + if hasattr(field.field_info, 'ge'): + constraints.min_value = field.field_info.ge + elif hasattr(field.field_info, 'gt'): + constraints.min_value = field.field_info.gt + (1 if isinstance(field.field_info.gt, int) else Decimal('0')) + + if hasattr(field.field_info, 'le'): + constraints.max_value = field.field_info.le + elif hasattr(field.field_info, 'lt'): + constraints.max_value = field.field_info.lt - (1 if isinstance(field.field_info.lt, int) else Decimal('0')) + + # Determine parameter type + param_type = None if "connector" in name.lower(): - param.is_connector = True - if "trading_pair" in name.lower(): - param.is_trading_pair = True - if any(word in name.lower() for word in ["percentage", "percent", "ratio", "pct"]): - param.is_percentage = True - if "price" in name.lower(): - param.is_price = True - if param.min_value is None: - param.min_value = Decimal(0) - if "amount" in name.lower(): - param.min_value = Decimal(0) - if any(word in name.lower() for word in ["time", "interval", "duration"]): - param.is_timespan = True - param.min_value = 0 - if param.type == "int": - param.is_integer = True - if any(word in name.lower() for word in ["executors", "workers"]): - param.display_type = "slider" - param.min_value = 1 - try: - if issubclass(field.type_, Enum): - param.valid_values = [item.value for item in field.type_] - param.display_type = "dropdown" - except: - pass - return param - -def determine_parameter_group(name: str) -> str: - if any(word in name.lower() for word in ["controller_name", "candles", "interval"]): - return "General Settings" - elif any(word in name.lower() for word in ["stop_loss", "trailing_stop", "take_profit", "activation_bounds", "leverage", "triple_barrier"]): - return "Risk Management" - elif "buy" in name.lower(): - return "Buy Order Settings" - elif "sell" in name.lower(): - return "Sell Order Settings" - elif "dca" in name.lower(): - return "DCA Settings" - elif any(word in name.lower() for word in ["bb", "macd", "natr", "length", "multiplier"]): - return "Indicator Settings" - elif any(word in name.lower() for word in ["profitability", "position_size"]): - return "Profitability Settings" - elif any(word in name.lower() for word in ["time_limit", "executor", "imbalance"]): - return "Execution Settings" - elif any(word in name.lower() for word in ["spot", "perp"]): - return "Arbitrage Settings" + param_type = ParameterType.CONNECTOR + elif "trading_pair" in name.lower(): + param_type = ParameterType.TRADING_PAIR + elif any(word in name.lower() for word in ["percentage", "percent", "ratio", "pct"]): + param_type = ParameterType.PERCENTAGE + elif "price" in name.lower(): + param_type = ParameterType.PRICE + elif any(word in name.lower() for word in ["time", "interval", "duration"]): + param_type = ParameterType.TIMESPAN + elif str(field.type_.__name__).lower() == "int": + param_type = ParameterType.INTEGER + elif str(field.type_.__name__).lower() == "decimal": + param_type = ParameterType.DECIMAL + elif str(field.type_.__name__).lower() == "bool": + param_type = ParameterType.BOOLEAN + else: + param_type = ParameterType.STRING + + # Determine display type + display_type = DisplayType.INPUT + if constraints.min_value is not None and constraints.max_value is not None: + display_type = DisplayType.SLIDER + elif param_type == ParameterType.BOOLEAN: + display_type = DisplayType.TOGGLE + elif constraints.valid_values: + display_type = DisplayType.DROPDOWN + + # Get group + group = determine_parameter_group(name) + + return StrategyParameter( + name=name, + type=str(field.type_.__name__), + required=field.required or field.default is not None, + default=field.default, + display_name=name.replace('_', ' ').title(), + description=field.description if hasattr(field, 'description') else "", + group=group, + is_advanced=is_advanced_parameter(name), + constraints=constraints, + display_type=display_type, + parameter_type=param_type + ) + +def determine_parameter_group(name: str) -> ParameterGroup: + """Determine the parameter group based on the parameter name""" + name_lower = name.lower() + + if any(word in name_lower for word in ["controller_name", "candles", "interval"]): + return ParameterGroup.GENERAL + elif any(word in name_lower for word in ["stop_loss", "trailing_stop", "take_profit", "activation_bounds", "leverage", "triple_barrier"]): + return ParameterGroup.RISK + elif "buy" in name_lower: + return ParameterGroup.BUY + elif "sell" in name_lower: + return ParameterGroup.SELL + elif "dca" in name_lower: + return ParameterGroup.DCA + elif any(word in name_lower for word in ["bb", "macd", "natr", "length", "multiplier"]): + return ParameterGroup.INDICATORS + elif any(word in name_lower for word in ["profitability", "position_size"]): + return ParameterGroup.PROFITABILITY + elif any(word in name_lower for word in ["time_limit", "executor", "imbalance"]): + return ParameterGroup.EXECUTION + elif any(word in name_lower for word in ["spot", "perp"]): + return ParameterGroup.ARBITRAGE else: - return "Other" + return ParameterGroup.OTHER +def snake_case_to_real_name(snake_case: str) -> str: + return " ".join([word.capitalize() for word in snake_case.split("_")]) + +def infer_strategy_type(module_path: str, config_class: Any) -> StrategyType: + """Infer the strategy type from the module path and config class""" + if "directional_trading" in module_path: + return StrategyType.DIRECTIONAL_TRADING + elif "market_making" in module_path: + return StrategyType.MARKET_MAKING + else: + return StrategyType.GENERIC + +def generate_strategy_mapping(module_path: str, config_class: Any) -> StrategyMapping: + """Generate a strategy mapping from a config class""" + # Extract strategy ID from module path (e.g., "supertrend_v1" from "bots.controllers.directional_trading.supertrend_v1") + strategy_id = module_path.split(".")[-1] + + # Get strategy type + strategy_type = infer_strategy_type(module_path, config_class) + + # Generate display name + display_name = " ".join(word.capitalize() for word in strategy_id.split("_")) + + # Get description from class docstring + description = config_class.__doc__ or "" + + return StrategyMapping( + id=strategy_id, + config_class=config_class.__name__, + module_path=module_path, + strategy_type=strategy_type, + display_name=display_name, + description=description + ) @functools.lru_cache(maxsize=1) -def get_all_strategy_maps() -> Dict[str, Dict[str, StrategyParameter]]: - strategy_maps = {} +def discover_strategies() -> Dict[str, Strategy]: + """Discover and load all available strategies""" + strategies = {} controllers_dir = "bots/controllers" for root, dirs, files in os.walk(controllers_dir): @@ -230,17 +355,87 @@ def get_all_strategy_maps() -> Dict[str, Dict[str, StrategyParameter]]: and obj is not DirectionalTradingControllerConfigBase ): assert isinstance(obj, ModelMetaclass) - strategy_name = obj.controller_name if hasattr(obj, 'controller_name') else name.lower() + + # Extract strategy ID from module path + strategy_id = module_path.split(".")[-1] + + # Get strategy type + strategy_type = infer_strategy_type(module_path, obj) + + # Get display info + display_info = get_strategy_display_info().get(strategy_id, {}) + + # Convert parameters parameters = {} for field_name, field in obj.__fields__.items(): param = convert_to_strategy_parameter(field_name, field) parameters[field_name] = param - strategy_maps[strategy_name] = parameters + # Create strategy + strategies[strategy_id] = Strategy( + id=strategy_id, + name=display_info.get("pretty_name", " ".join(word.capitalize() for word in strategy_id.split("_"))), + description=display_info.get("description", obj.__doc__ or ""), + type=strategy_type, + module_path=module_path, + config_class=obj.__name__, + parameters=parameters + ) + except ImportError as e: - print(f"Error importing module {module_path}: {e}") + logger.error(f"Error importing module {module_path}: {e}") except Exception as e: - print(f"Unexpected error processing {module_path}: {e}") + logger.error(f"Unexpected error processing {module_path}: {e}") import traceback traceback.print_exc() - return strategy_maps + + return strategies + +def get_strategy_mapping(strategy_id: str) -> Optional[StrategyMapping]: + """ + DEPRECATED: Use StrategyRegistry.get_strategy() instead + Get strategy mapping by ID + """ + logger.warning("get_strategy_mapping is deprecated. Use StrategyRegistry.get_strategy() instead") + strategy = StrategyRegistry.get_strategy(strategy_id) + if not strategy: + return None + return StrategyMapping( + id=strategy.id, + config_class=strategy.config_class, + module_path=strategy.module_path, + strategy_type=strategy.type, + display_name=strategy.name, + description=strategy.description + ) + +def get_strategy_module_path(strategy_id: str) -> Optional[str]: + """ + DEPRECATED: Use StrategyRegistry.get_strategy().module_path instead + Get the module path for a strategy + """ + logger.warning("get_strategy_module_path is deprecated. Use StrategyRegistry.get_strategy().module_path instead") + strategy = StrategyRegistry.get_strategy(strategy_id) + return strategy.module_path if strategy else None + +def is_advanced_parameter(name: str) -> bool: + """Determine if a parameter should be considered advanced""" + advanced_keywords = [ + "activation_bounds", "triple_barrier", "leverage", "dca", "macd", "natr", + "multiplier", "imbalance", "executor", "perp", "arbitrage" + ] + + simple_keywords = [ + "controller_name", "candles", "interval", "stop_loss", "take_profit", + "buy", "sell", "position_size", "time_limit", "spot" + ] + + name_lower = name.lower() + + if any(keyword in name_lower for keyword in advanced_keywords): + return True + + if any(keyword in name_lower for keyword in simple_keywords): + return False + + return True diff --git a/services/__init__.py b/services/__init__.py index e69de29..9c0bf67 100644 --- a/services/__init__.py +++ b/services/__init__.py @@ -0,0 +1,3 @@ +""" +Services package initialization. +""" diff --git a/services/libert_ai_service.py b/services/libert_ai_service.py new file mode 100644 index 0000000..cbe2a8d --- /dev/null +++ b/services/libert_ai_service.py @@ -0,0 +1,421 @@ +import json +import logging +import aiohttp +import inspect +import importlib +from typing import Dict, Any, List, Optional +from routers.strategies_models import ( + ParameterSuggestion, + StrategyConfig, + StrategyMapping, + discover_strategies +) + +logger = logging.getLogger(__name__) + +class LibertAIService: + def __init__(self): + # Hermes 2 pro + self.api_url = "https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion" + + self.strategy_slot_map: Dict[str, int] = {} # Maps strategy IDs to their slot IDs + self.next_slot_id = 0 + + async def initialize_contexts(self, strategies: Dict[str, StrategyConfig]): + """Initialize context slots for system prompt and each strategy.""" + try: + logger.info("Starting context initialization...") + + # Initialize system prompt in slot -1 + logger.info("Initializing system context...") + await self._initialize_system_context() + + # Initialize each strategy's context + for strategy_id, strategy_config in strategies.items(): + logger.info(f"Initializing context for strategy: {strategy_id}") + slot_id = self.next_slot_id + self.strategy_slot_map[strategy_id] = slot_id + self.next_slot_id += 1 + + # Load strategy implementation code + strategy_code = await self._load_strategy_code(strategy_config.mapping) + logger.info(f"Loaded strategy code for {strategy_id}, code length: {len(strategy_code)}") + + await self._initialize_strategy_context( + strategy_mapping=strategy_config.mapping, + strategy_config=strategy_config.parameters, + strategy_code=strategy_code, + slot_id=slot_id + ) + + logger.info(f"Context initialization complete. Strategy slot map: {self.strategy_slot_map}") + + except Exception as e: + logger.error(f"Error initializing contexts: {str(e)}") + raise + + async def _load_strategy_code(self, mapping: StrategyMapping) -> str: + """Load the strategy implementation code using the strategy mapping.""" + try: + # Import the module using the mapping's module path + module = importlib.import_module(mapping.module_path) + + # Get all classes in the module + strategy_classes = inspect.getmembers( + module, + lambda member: ( + inspect.isclass(member) + and member.__module__ == module.__name__ + and not member.__name__.endswith('Config') + ) + ) + + if not strategy_classes: + raise ValueError(f"No strategy class found in {mapping.module_path}") + + # Get the source code of the strategy class + strategy_class = strategy_classes[0][1] # Take the first class + source_code = inspect.getsource(strategy_class) + + return source_code + + except Exception as e: + logger.error(f"Error loading strategy code for {mapping.id}: {str(e)}") + return f"# Strategy implementation code not found for {mapping.id}" + + async def _initialize_system_context(self): + """Initialize the system prompt in slot -1.""" + system_prompt = """<|im_start|>system +You are an expert trading strategy advisor. Your task is to analyze trading strategies and suggest optimal parameter values. + +For each parameter suggestion, you must: +1. Suggest an appropriate value based on the strategy type and configuration +2. Provide a clear explanation of why this value would be appropriate +3. Consider potential risks and market conditions +4. Take into account the strategy's implementation code and logic + +Format each suggestion exactly as follows: +PARAMETER: [parameter_name] +VALUE: [suggested_value] +REASONING: [detailed explanation] +<|im_end|>""" + + try: + async with aiohttp.ClientSession() as session: + await session.post( + self.api_url, + headers={"Content-Type": "application/json"}, + json={ + "prompt": system_prompt, + "temperature": 0.9, + "top_p": 1, + "top_k": 40, + "n": 1, + "n_predict": 100, + "stop": ["<|im_end|>"] + } + ) + except Exception as e: + print(f"ERROR: Error initializing system context: {str(e)}") + raise + + async def _initialize_strategy_context( + self, + strategy_mapping: StrategyMapping, + strategy_config: Dict[str, Any], + strategy_code: str, + slot_id: int + ): + """Initialize context for a specific strategy.""" + # Convert strategy parameters to a serializable format + serializable_config = { + name: { + "name": param.name, + "group": param.group, + "type": param.type, + "prompt": param.prompt, + "default": str(param.default) if param.default is not None else None, + "required": param.required, + "min_value": str(param.min_value) if param.min_value is not None else None, + "max_value": str(param.max_value) if param.max_value is not None else None, + "is_advanced": param.is_advanced, + "display_type": param.display_type + } + for name, param in strategy_config.items() + } + + strategy_context = f"""<|im_start|>user +Trading Strategy: {strategy_mapping.display_name} +Type: {strategy_mapping.strategy_type.value} +Description: {strategy_mapping.description} + +Strategy Configuration Schema: +{json.dumps(serializable_config, indent=2)} + +Strategy Implementation: +```python +{strategy_code} +``` + +This strategy's configuration defines the parameters and their constraints, while the implementation shows how these parameters are used in the trading logic. Use both to make informed suggestions about parameter values. +<|im_end|>""" + + try: + async with aiohttp.ClientSession() as session: + await session.post( + self.api_url, + headers={"Content-Type": "application/json"}, + json={ + "prompt": strategy_context, + "temperature": 0.9, + "top_p": 1, + "top_k": 40, + "n": 1, + "n_predict": 100, + "stop": ["<|im_end|>"], + "slot_id": slot_id, + "parent_slot_id": -1, + } + ) + except Exception as e: + print(f"ERROR: Error initializing strategy context for {strategy_mapping.id}: {str(e)}") + raise + + async def get_parameter_suggestions( + self, + strategy_id: str, + strategy_config: Dict[str, Any], + provided_params: Dict[str, Any], + requested_params: Optional[List[str]] = None + ) -> List[ParameterSuggestion]: + """Get parameter suggestions from LibertAI. + + Args: + strategy_id: ID of the strategy + strategy_config: Full strategy configuration + provided_params: Parameters already provided by the user + requested_params: Optional list of specific parameters to get suggestions for + """ + print("\n=== Getting Parameter Suggestions ===") + print(f"Strategy ID: {strategy_id}") + print(f"Provided parameters: {json.dumps(provided_params, indent=2)}") + print(f"Requested parameters: {requested_params}") + + # Get strategy configuration + strategies = discover_strategies() + strategy = strategies.get(strategy_id) + if not strategy: + print(f"ERROR: No strategy found with ID {strategy_id}") + return [] + + # Identify missing required parameters and optional parameters + missing_required = [] + optional_params = [] + + # If specific parameters are requested, only consider those + params_to_check = requested_params if requested_params else strategy_config.keys() + + for param_name in params_to_check: + if param_name not in provided_params: + param_config = strategy_config.get(param_name) + if param_config: + if param_config.required: + missing_required.append(param_name) + else: + optional_params.append(param_name) + + print(f"Missing required parameters: {missing_required}") + print(f"Optional parameters: {optional_params}") + + # Get the strategy's slot ID + slot_id = self.strategy_slot_map.get(strategy_id) + if slot_id is None: + print(f"ERROR: No cached context found for strategy {strategy_id}") + return [] + + # Convert parameters to a serializable format + serializable_params = { + name: str(value) if hasattr(value, "__str__") else value + for name, value in provided_params.items() + } + + # Update the prompt to be more explicit about the format and requested parameters + optional_params_text = f"Optional Parameters That Could Be Set:\n{', '.join(optional_params) if optional_params else 'None'}" if not requested_params else "" + + request_prompt = f"""<|im_start|>user +Strategy: {strategy.mapping.display_name} +Type: {strategy.mapping.strategy_type.value} + +Currently Provided Parameters: +{json.dumps(serializable_params, indent=2)} + +{"Parameters to Suggest:" if requested_params else "Missing Required Parameters:"} +{', '.join(requested_params) if requested_params else ', '.join(missing_required) if missing_required else 'None'} + +{optional_params_text} + +Please suggest optimal values for {"the requested" if requested_params else "the missing"} parameters using exactly this format for each parameter: + +PARAMETER: [parameter_name] +VALUE: [suggested_value] +REASONING: [detailed explanation of why this value is appropriate] + +End with a summary: +SUMMARY: [overall explanation of the suggested configuration] + +Do not include code blocks or other formats. Use only the PARAMETER/VALUE/REASONING structure. +<|im_end|>""" + + try: + async with aiohttp.ClientSession() as session: + print(f"\nSending request to LibertAI API...") + print(f"Request prompt:\n{request_prompt}") + + request_payload = { + "slot_id": self.next_slot_id, + "parent_slot_id": slot_id, + "prompt": request_prompt, + "temperature": 0.9, + "top_p": 1, + "top_k": 40, + "n": 1, + "n_predict": 1500, + "stop": ["<|im_end|>"] + } + + async with session.post( + self.api_url, + headers={"Content-Type": "application/json"}, + json=request_payload + ) as response: + if response.status != 200: + print(f"ERROR: API returned status {response.status}") + response_text = await response.text() + print(f"Response body: {response_text}") + return [] + + result = await response.json() + print(f"\nReceived response from API: {json.dumps(result, indent=2)}") + return self._parse_ai_response( + {"choices": [{"message": {"content": result["content"]}}]}, + strategy_config=strategy_config, + provided_params=provided_params + ) + + except Exception as e: + print(f"ERROR: Exception during API call: {str(e)}") + return [] + + def _parse_ai_response(self, ai_response: Dict[str, Any], strategy_config: Dict[str, Any], provided_params: Dict[str, Any]) -> List[ParameterSuggestion]: + print("\n=== Parsing AI Response ===") + try: + content = ai_response["choices"][0]["message"]["content"] + print(f"Response content preview: {content[:200]}...") + + suggestions = [] + seen_params = set() + summary = None + + # Create a map of default values and provided values for comparison + default_values = { + name: str(param.default) if param.default is not None else None + for name, param in strategy_config.items() + } + + provided_values = { + name: str(value) if hasattr(value, "__str__") else str(value) + for name, value in provided_params.items() + } + + if "PARAMETER:" in content: + print("Found structured format with PARAMETER/VALUE/REASONING") + parameter_sections = content.split("PARAMETER:") + + for section in parameter_sections[1:]: + lines = section.strip().split("\n") + param_name = lines[0].strip() + + # Initialize collectors for multi-line values + value_lines = [] + reasoning_lines = [] + collecting_value = False + collecting_reasoning = False + + # Process remaining lines + for line in lines[1:]: + line = line.strip() + + if line.startswith("VALUE:"): + collecting_value = True + collecting_reasoning = False + value_lines.append(line.replace("VALUE:", "").strip()) + elif line.startswith("REASONING:"): + collecting_value = False + collecting_reasoning = True + reasoning_lines.append(line.replace("REASONING:", "").strip()) + elif line.startswith("SUMMARY:"): + collecting_value = False + collecting_reasoning = False + summary = line.replace("SUMMARY:", "").strip() + else: + # Continue collecting multi-line values + if collecting_value and line: + value_lines.append(line) + elif collecting_reasoning and line: + reasoning_lines.append(line) + + # Process collected values + if param_name and value_lines and param_name not in seen_params: + seen_params.add(param_name) + + # Join multi-line values and try to parse as JSON if it looks like a JSON structure + value = "\n".join(value_lines) + if value.strip().startswith("{") and value.strip().endswith("}"): + try: + parsed_value = json.loads(value) + value = json.dumps(parsed_value) + except json.JSONDecodeError: + pass + + # Compare with default and provided values + differs_from_default = ( + param_name in default_values and + default_values[param_name] is not None and + value != default_values[param_name] + ) + differs_from_provided = ( + param_name in provided_values and + value != provided_values[param_name] + ) + + suggestions.append(ParameterSuggestion( + parameter_name=param_name, + suggested_value=value, + reasoning="\n".join(reasoning_lines) if reasoning_lines else "No reasoning provided", + differs_from_default=differs_from_default, + differs_from_provided=differs_from_provided + )) + + if summary: + suggestions.append(ParameterSuggestion( + parameter_name="summary", + suggested_value=summary, + reasoning="Summary of the suggested configuration", + differs_from_default=False, + differs_from_provided=False + )) + + print(f"\nTotal suggestions parsed: {len(suggestions)}") + for s in suggestions: + print(f"- {s.parameter_name}: {s.suggested_value}") + if s.differs_from_default: + print(f" (differs from default: {s.differs_from_default})") + if s.differs_from_provided: + print(f" (differs from provided: {s.differs_from_provided})") + + return suggestions + + except Exception as e: + print(f"ERROR: Failed to parse AI response: {str(e)}") + print(f"Raw response: {json.dumps(ai_response, indent=2)}") + return [] \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..9d55c93 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Tests package initialization. +""" \ No newline at end of file diff --git a/tests/test_libert_ai_service.py b/tests/test_libert_ai_service.py new file mode 100644 index 0000000..b8c38c6 --- /dev/null +++ b/tests/test_libert_ai_service.py @@ -0,0 +1,164 @@ +import pytest +from services.libert_ai_service import LibertAIService +from routers.strategies_models import ( + ParameterSuggestion, + discover_strategies, +) +from typing import Any +from dataclasses import dataclass + +@pytest.fixture +def libert_ai_service(): + service = LibertAIService() + return service + +@pytest.fixture +def strategy_configs(): + """Load all available strategies""" + return discover_strategies() + +@pytest.fixture +def bollinger_strategy(strategy_configs): + """Get the Bollinger strategy configuration""" + return strategy_configs["bollinger_v1"] + +@pytest.mark.asyncio +async def test_initialize_system_context(libert_ai_service): + """Test system context initialization""" + await libert_ai_service._initialize_system_context() + +@pytest.mark.asyncio +async def test_initialize_strategy_context(libert_ai_service, bollinger_strategy): + """Test strategy context initialization""" + with open(f"bots/{bollinger_strategy.mapping.module_path.split('bots.')[-1].replace('.', '/')}.py", "r") as f: + strategy_code = f.read() + + await libert_ai_service._initialize_strategy_context( + strategy_mapping=bollinger_strategy.mapping, + strategy_config=bollinger_strategy.parameters, + strategy_code=strategy_code, + slot_id=0 + ) + +@pytest.mark.asyncio +async def test_get_parameter_suggestions(libert_ai_service, bollinger_strategy): + """Test parameter suggestion generation""" + + libert_ai_service.strategy_slot_map["bollinger_v1"] = 0 + + suggestions = await libert_ai_service.get_parameter_suggestions( + strategy_id="bollinger_v1", + strategy_config=bollinger_strategy.parameters, + provided_params={"bb_std": 2.0} + ) + + # Verify suggestions are part of the bollinger_strategy.parameters + for suggestion in suggestions: + assert suggestion.parameter_name in bollinger_strategy.parameters or suggestion.parameter_name == "summary" + +@pytest.mark.asyncio +async def test_parse_ai_response(libert_ai_service): + """Test AI response parsing""" + ai_response = { + "choices": [{ + "message": { + "content": """ +PARAMETER: bb_length +VALUE: 100 +REASONING: Standard length for Bollinger Bands calculation provides reliable signals while filtering out noise. + +PARAMETER: bb_long_threshold +VALUE: 0.2 +REASONING: Enter long positions when price is 20% below the middle band, indicating oversold conditions. + +SUMMARY: These parameters are optimized for mean reversion trading using Bollinger Bands. +""" + } + }] + } + + # Mock strategy config with proper parameter objects + @dataclass + class MockParameter: + name: str + default: Any + required: bool = True + type: str = "float" + + strategy_config = { + "bb_length": MockParameter( + name="BB Length", + default=20, + type="int" + ), + "bb_long_threshold": MockParameter( + name="BB Long Threshold", + default=0.1, + type="float" + ) + } + + provided_params = { + "bb_length": 20 + } + + suggestions = libert_ai_service._parse_ai_response( + ai_response, + strategy_config=strategy_config, + provided_params=provided_params + ) + + assert len(suggestions) == 3 # 2 parameters + summary + assert all(isinstance(s, ParameterSuggestion) for s in suggestions) + assert suggestions[0].parameter_name == "bb_length" + assert suggestions[0].suggested_value == "100" + assert suggestions[0].differs_from_default is True # 100 vs default 20 + assert suggestions[0].differs_from_provided is True # 100 vs provided 20 + assert suggestions[1].parameter_name == "bb_long_threshold" + assert suggestions[1].suggested_value == "0.2" + assert suggestions[1].differs_from_default is True # 0.2 vs default 0.1 + assert suggestions[1].differs_from_provided is False # Not provided + assert suggestions[2].parameter_name == "summary" + assert suggestions[2].suggested_value == "These parameters are optimized for mean reversion trading using Bollinger Bands." + +@pytest.mark.asyncio +async def test_parse_ai_response_handles_invalid_format(libert_ai_service): + """Test handling of invalid AI response format""" + invalid_response = { + "choices": [{ + "message": { + "content": "Invalid format response" + } + }] + } + + # Mock empty config with proper parameter objects + strategy_config = {} + provided_params = {} + + suggestions = libert_ai_service._parse_ai_response( + invalid_response, + strategy_config=strategy_config, + provided_params=provided_params + ) + assert suggestions == [] + +@pytest.mark.asyncio +async def test_get_specific_parameter_suggestions(libert_ai_service, bollinger_strategy): + """Test getting suggestions for specific parameters""" + + libert_ai_service.strategy_slot_map["bollinger_v1"] = 0 + + # Request suggestions for specific parameters + requested_params = ["bb_length", "bb_long_threshold"] + suggestions = await libert_ai_service.get_parameter_suggestions( + strategy_id="bollinger_v1", + strategy_config=bollinger_strategy.parameters, + provided_params={"bb_std": 2.0}, + requested_params=requested_params + ) + + # Verify we only got suggestions for the requested parameters (plus summary) + assert len(suggestions) == 3 # 2 requested parameters + summary + suggestion_params = {s.parameter_name for s in suggestions if s.parameter_name != "summary"} + assert suggestion_params == set(requested_params) \ No newline at end of file diff --git a/tests/test_strategies.py b/tests/test_strategies.py new file mode 100644 index 0000000..91f3b16 --- /dev/null +++ b/tests/test_strategies.py @@ -0,0 +1,234 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from decimal import Decimal +from typing import Dict, Any +from pydantic import BaseModel, Field +from hummingbot.strategy_v2.controllers import ControllerConfigBase + +from routers.strategies_models import ( + StrategyType, + StrategyMapping, + StrategyParameter, + StrategyConfig, + discover_strategies, + generate_strategy_mapping, + convert_to_strategy_parameter, + infer_strategy_type +) + +# Mock strategy config class for testing +class MockStrategyConfig(ControllerConfigBase): + """Test strategy for unit testing""" + controller_name = "test_strategy_v1" + + stop_loss: Decimal = Field( + default=Decimal("0.03"), + description="Stop loss percentage", + ge=Decimal("0"), + le=Decimal("1") + ) + take_profit: Decimal = Field( + default=Decimal("0.02"), + description="Take profit percentage", + ge=Decimal("0"), + le=Decimal("1") + ) + time_limit: int = Field( + default=2700, + description="Time limit in seconds", + gt=0 + ) + leverage: int = Field( + default=20, + description="Leverage multiplier", + gt=0 + ) + trading_pair: str = Field( + default="BTC-USDT", + description="Trading pair to use" + ) + +# Test data +MOCK_MODULE_PATH = "bots.controllers.directional_trading.test_strategy_v1" + +@pytest.fixture +def mock_strategy_config(): + return MockStrategyConfig + +@pytest.fixture(autouse=True) +def mock_importlib(): + with patch("importlib.import_module") as mock: + mock.return_value = MagicMock( + __name__="test_module", + MockStrategyConfig=MockStrategyConfig + ) + yield mock + +@pytest.fixture(autouse=True) +def mock_os_walk(): + with patch("os.walk") as mock: + mock.return_value = [ + ("bots/controllers/directional_trading", [], ["test_strategy_v1.py"]), + ] + yield mock + +@pytest.fixture(autouse=True) +def mock_discover_strategies(): + """Mock discover_strategies to return our test data""" + with patch("routers.strategies_models.discover_strategies", autospec=True) as mock: + mock.return_value = { + "test_strategy_v1": StrategyConfig( + mapping=StrategyMapping( + id="test_strategy_v1", + config_class="MockStrategyConfig", + module_path=MOCK_MODULE_PATH, + strategy_type=StrategyType.DIRECTIONAL_TRADING, + display_name="Test Strategy V1", + description="Test strategy for unit testing" + ), + parameters={ + "stop_loss": StrategyParameter( + name="stop_loss", + pretty_name="Stop Loss", + description="Stop loss percentage", + group="Risk Management", + type="Decimal", + prompt="Enter stop loss value", + default=Decimal("0.03"), + required=True, + min_value=Decimal("0"), + max_value=Decimal("1") + ), + "take_profit": StrategyParameter( + name="take_profit", + pretty_name="Take Profit", + description="Take profit percentage", + group="Risk Management", + type="Decimal", + prompt="Enter take profit value", + default=Decimal("0.02"), + required=True, + min_value=Decimal("0"), + max_value=Decimal("1") + ), + "time_limit": StrategyParameter( + name="time_limit", + pretty_name="Time Limit", + description="Time limit in seconds", + group="General Settings", + type="int", + prompt="Enter time limit in seconds", + default=2700, + required=True, + min_value=0 + ), + "leverage": StrategyParameter( + name="leverage", + pretty_name="Leverage", + description="Leverage multiplier", + group="Risk Management", + type="int", + prompt="Enter leverage multiplier", + default=20, + required=True, + min_value=1, + is_advanced=True + ), + "trading_pair": StrategyParameter( + name="trading_pair", + pretty_name="Trading Pair", + description="Trading pair to use", + group="General Settings", + type="str", + prompt="Enter trading pair", + default="BTC-USDT", + required=True, + is_trading_pair=True + ) + } + ) + } + yield mock + +def test_infer_strategy_type(): + """Test strategy type inference from module path""" + assert infer_strategy_type("bots.controllers.directional_trading.test", None) == StrategyType.DIRECTIONAL_TRADING + assert infer_strategy_type("bots.controllers.market_making.test", None) == StrategyType.MARKET_MAKING + assert infer_strategy_type("bots.controllers.generic.test", None) == StrategyType.GENERIC + +def test_generate_strategy_mapping(): + """Test strategy mapping generation""" + mapping = generate_strategy_mapping(MOCK_MODULE_PATH, MockStrategyConfig) + + assert mapping.id == "test_strategy_v1" + assert mapping.config_class == "MockStrategyConfig" + assert mapping.module_path == MOCK_MODULE_PATH + assert mapping.strategy_type == StrategyType.DIRECTIONAL_TRADING + assert mapping.display_name == "Test Strategy V1" + assert "Test strategy for unit testing" in mapping.description + +def test_convert_to_strategy_parameter(): + """Test parameter conversion from config field""" + # Get a field from the mock config + field = MockStrategyConfig.__fields__["stop_loss"] + param = convert_to_strategy_parameter("stop_loss", field) + + assert param.pretty_name == "Stop Loss" + assert param.group == "Risk Management" + assert param.type == "ConstrainedDecimalValue" # We want the base type, not the constrained type + assert param.default == Decimal("0.03") + assert param.required is True + assert param.min_value == Decimal("0") + assert param.max_value == Decimal("1") + assert param.display_type == "slider" + +@pytest.mark.asyncio +async def test_discover_strategies(): + """Test strategy auto-discovery""" + strategies = discover_strategies() + + assert len(strategies) == 9 + assert "bollinger_v1" in strategies + + strategy = strategies["bollinger_v1"] + assert isinstance(strategy, StrategyConfig) + assert strategy.mapping.id == "bollinger_v1" + + # Check some parameters + assert "stop_loss" in strategy.parameters + assert "take_profit" in strategy.parameters + assert strategy.parameters["leverage"].is_advanced is True + assert strategy.parameters["trading_pair"].is_trading_pair is True + + +def test_parameter_validation(): + """Test parameter validation in strategy config""" + # Test required parameters + with pytest.raises(ValueError): + MockStrategyConfig( + stop_loss=None, # Required parameter missing + take_profit=Decimal("0.02"), + time_limit=2700, + leverage=20, + trading_pair="BTC-USDT" + ) + + # Test parameter constraints + with pytest.raises(ValueError): + MockStrategyConfig( + stop_loss=Decimal("-0.03"), # Negative value not allowed + take_profit=Decimal("0.02"), + time_limit=2700, + leverage=20, + trading_pair="BTC-USDT" + ) + +def test_strategy_type_enum(): + """Test StrategyType enum values""" + assert StrategyType.DIRECTIONAL_TRADING == "directional_trading" + assert StrategyType.MARKET_MAKING == "market_making" + assert StrategyType.GENERIC == "generic" + + # Test that invalid types are not allowed + with pytest.raises(ValueError): + StrategyType("invalid_type") \ No newline at end of file