88 changes: 52 additions & 36 deletions runtime/launch.py
@@ -1,6 +1,6 @@
from typing_extensions import Literal
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, model_validator
from typing import Optional, Dict, Any, Union
from dotenv import load_dotenv
import traceback
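
The import change above moves from Pydantic v1's root_validator to v2's model_validator. A quick, illustrative guard for that requirement (not part of the diff):

# Illustrative only: model_validator is a Pydantic v2 API, so a v2 install is assumed.
import pydantic

assert pydantic.VERSION.startswith("2"), "model_validator requires pydantic>=2"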
@@ -34,11 +34,6 @@

import uvicorn

# from cerebrum.llm.layer import LLMLayer as LLMConfig
# from cerebrum.memory.layer import MemoryLayer as MemoryConfig
# from cerebrum.storage.layer import StorageLayer as StorageConfig
# from cerebrum.tool.layer import ToolLayer as ToolManagerConfig

load_dotenv()

app = FastAPI()
@@ -52,6 +47,7 @@
)

# Store component configurations and instances
global active_components
active_components = {
"llm": None,
"storage": None,
@@ -113,28 +109,31 @@ class QueryRequest(BaseModel):
query_type: Literal["llm", "tool", "storage", "memory"]
query_data: LLMQuery | ToolQuery | StorageQuery | MemoryQuery

@root_validator(pre=True)
def convert_query_data(cls, values: dict[str, Any]) -> dict[str, Any]:
if 'query_type' not in values or 'query_data' not in values:
return values

query_type = values['query_type']
query_data = values['query_data']

type_mapping = {
'llm': LLMQuery,
'tool': ToolQuery,
'storage': StorageQuery,
'memory': MemoryQuery
}

if isinstance(query_data, type_mapping[query_type]):
return values

if isinstance(query_data, dict):
values['query_data'] = type_mapping[query_type](**query_data)

return values
@model_validator(mode='before')
def convert_query_data(cls, data: Any) -> Any:
if isinstance(data, dict):
query_type = data.get('query_type')
query_data = data.get('query_data')

if not query_type or not query_data:
return data

type_mapping = {
'llm': LLMQuery,
'tool': ToolQuery,
'storage': StorageQuery,
'memory': MemoryQuery
}

target_type = type_mapping.get(query_type)
if target_type and isinstance(query_data, dict) and not isinstance(query_data, target_type):
try:
data['query_data'] = target_type(**query_data)
except Exception as e:
# If coercion fails, return the original payload unchanged and let
# Pydantic surface the validation error on query_data
pass
return data
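
For reference, a standalone sketch of the query_data coercion pattern used by convert_query_data above; the Stub* models are hypothetical stand-ins for the real LLMQuery/ToolQuery classes, and their field names are invented for illustration.

from typing import Any, Literal

from pydantic import BaseModel, model_validator


class StubLLMQuery(BaseModel):
    prompt: str


class StubToolQuery(BaseModel):
    tool_name: str


class StubQueryRequest(BaseModel):
    query_type: Literal["llm", "tool"]
    query_data: StubLLMQuery | StubToolQuery

    @model_validator(mode="before")
    @classmethod
    def convert_query_data(cls, data: Any) -> Any:
        # Coerce a plain dict payload into the model matching query_type so the
        # union field validates against the intended type.
        if isinstance(data, dict):
            mapping = {"llm": StubLLMQuery, "tool": StubToolQuery}
            target = mapping.get(data.get("query_type"))
            if target and isinstance(data.get("query_data"), dict):
                data["query_data"] = target(**data["query_data"])
        return data


# Usage: a raw dict payload is coerced into the matching query model.
req = StubQueryRequest(query_type="llm", query_data={"prompt": "hello"})
assert isinstance(req.query_data, StubLLMQuery)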

def initialize_llm_cores(config: dict) -> Any:
"""Initialize LLM core with configuration."""
@@ -331,8 +330,8 @@ def restart_kernel():

# Initialize new components
active_components = initialize_components()
# if not initialize_components():
# raise Exception("Failed to initialize components")
if not active_components:
raise Exception("Failed to initialize components")

print("✅ All components reinitialized successfully")

@@ -345,11 +344,11 @@ def restart_kernel():
async def refresh_configuration():
"""Refresh all component configurations"""
try:
print("Received refresh request")
logger.info("Received refresh request")
config.refresh()
print("Configuration reloaded")
logger.info("Configuration reloaded")
restart_kernel()
print("Kernel restarted")
logger.info("Kernel restarted")
return {
"status": "success",
"message": "Configuration refreshed and components reinitialized"
@@ -472,12 +471,12 @@ async def submit_agent(config: AgentSubmit):
config=config.agent_config
)


return {
"status": "success",
"execution_id": execution_id,
"message": f"Agent {config.agent_id} submitted for execution"
}

except Exception as e:
error_msg = str(e)
stack_trace = traceback.format_exc()
@@ -609,14 +608,31 @@ async def update_config(request: Request):
data = await request.json()
logger.info(f"Received config update request: {data}")

name = data.get("name")
provider = data.get("provider")
api_key = data.get("api_key")

if not all([provider, api_key]):
raise ValueError("Missing required fields: provider, api_key")

# Update configuration
config.config["api_keys"][provider] = api_key
config.config["llms"]["models"] = [
{
"name": name,
"backend": provider,
# API key is better stored in api_keys section
}
]

# Update API keys section if api_key is provided
if api_key:
if "api_keys" not in config.config:
config.config["api_keys"] = {}
config.config["api_keys"][provider] = api_key # Use backend name as the key


else:
raise ValueError("Missing required fields: api_key")

config.save_config()

# Try to reinitialize LLM component
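
For orientation, the config shape that update_config writes, based on the assignments in the hunk above; the provider, model name, and key value are illustrative, and the real config may hold additional sections.

# Hypothetical example of config.config after a successful config update request.
example_config = {
    "api_keys": {
        "openai": "sk-...",  # keyed by the provider/backend name
    },
    "llms": {
        "models": [
            {"name": "gpt-4o-mini", "backend": "openai"},
        ],
    },
}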