From 7b854a5177a20eac20a9a08eb591a233155a9043 Mon Sep 17 00:00:00 2001 From: dongyuanjushi Date: Fri, 18 Apr 2025 15:05:57 -0400 Subject: [PATCH] fix configuration refresh --- runtime/launch.py | 88 ++++++++++++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 36 deletions(-) diff --git a/runtime/launch.py b/runtime/launch.py index a6ab59f8..6ee132a8 100644 --- a/runtime/launch.py +++ b/runtime/launch.py @@ -1,6 +1,6 @@ from typing_extensions import Literal from fastapi import FastAPI, HTTPException, Request -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, model_validator from typing import Optional, Dict, Any, Union from dotenv import load_dotenv import traceback @@ -34,11 +34,6 @@ import uvicorn -# from cerebrum.llm.layer import LLMLayer as LLMConfig -# from cerebrum.memory.layer import MemoryLayer as MemoryConfig -# from cerebrum.storage.layer import StorageLayer as StorageConfig -# from cerebrum.tool.layer import ToolLayer as ToolManagerConfig - load_dotenv() app = FastAPI() @@ -52,6 +47,7 @@ ) # Store component configurations and instances +global active_components active_components = { "llm": None, "storage": None, @@ -113,28 +109,31 @@ class QueryRequest(BaseModel): query_type: Literal["llm", "tool", "storage", "memory"] query_data: LLMQuery | ToolQuery | StorageQuery | MemoryQuery - @root_validator(pre=True) - def convert_query_data(cls, values: dict[str, Any]) -> dict[str, Any]: - if 'query_type' not in values or 'query_data' not in values: - return values - - query_type = values['query_type'] - query_data = values['query_data'] - - type_mapping = { - 'llm': LLMQuery, - 'tool': ToolQuery, - 'storage': StorageQuery, - 'memory': MemoryQuery - } - - if isinstance(query_data, type_mapping[query_type]): - return values - - if isinstance(query_data, dict): - values['query_data'] = type_mapping[query_type](**query_data) - - return values + @model_validator(mode='before') + def convert_query_data(cls, data: Any) -> Any: + if isinstance(data, dict): + query_type = data.get('query_type') + query_data = data.get('query_data') + + if not query_type or not query_data: + return data + + type_mapping = { + 'llm': LLMQuery, + 'tool': ToolQuery, + 'storage': StorageQuery, + 'memory': MemoryQuery + } + + target_type = type_mapping.get(query_type) + if target_type and isinstance(query_data, dict) and not isinstance(query_data, target_type): + try: + data['query_data'] = target_type(**query_data) + except Exception as e: + # Handle potential validation errors if needed + # For now, just pass the original data + pass + return data def initialize_llm_cores(config: dict) -> Any: """Initialize LLM core with configuration.""" @@ -331,8 +330,8 @@ def restart_kernel(): # Initialize new components active_components = initialize_components() - # if not initialize_components(): - # raise Exception("Failed to initialize components") + if not active_components: + raise Exception("Failed to initialize components") print("✅ All components reinitialized successfully") @@ -345,11 +344,11 @@ def restart_kernel(): async def refresh_configuration(): """Refresh all component configurations""" try: - print("Received refresh request") + logger.info("Received refresh request") config.refresh() - print("Configuration reloaded") + logger.info("Configuration reloaded") restart_kernel() - print("Kernel restarted") + logger.info("Kernel restarted") return { "status": "success", "message": "Configuration refreshed and components reinitialized" @@ -472,12 +471,12 @@ async def submit_agent(config: AgentSubmit): config=config.agent_config ) - return { "status": "success", "execution_id": execution_id, "message": f"Agent {config.agent_id} submitted for execution" } + except Exception as e: error_msg = str(e) stack_trace = traceback.format_exc() @@ -609,14 +608,31 @@ async def update_config(request: Request): data = await request.json() logger.info(f"Received config update request: {data}") + name = data.get("name") provider = data.get("provider") api_key = data.get("api_key") if not all([provider, api_key]): raise ValueError("Missing required fields: provider, api_key") - # Update configuration - config.config["api_keys"][provider] = api_key + config.config["llms"]["models"] = [ + { + "name": name, + "backend": provider, + # API key is better stored in api_keys section + } + ] + + # Update API keys section if api_key is provided + if api_key: + if "api_keys" not in config.config: + config.config["api_keys"] = {} + config.config["api_keys"][provider] = api_key # Use backend name as the key + + + else: + raise ValueError("Missing required fields: api_key") + config.save_config() # Try to reinitialize LLM component