
Commit ad78305

feat: Integrate OpenRouter API key management and model fetching into settings
1 parent 828ec44 commit ad78305

9 files changed: +995, -255 lines changed


backend/agent_manager.py

Lines changed: 104 additions & 36 deletions
@@ -3,6 +3,8 @@
 from typing import Dict, Any, List, Tuple, Optional, Union, TypedDict
 from dotenv import load_dotenv
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_openai import ChatOpenAI
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain.schema import StrOutputParser
 from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
@@ -207,8 +209,8 @@ class ChapterGenerationState(TypedDict):
     instructions: Dict[str, Any]
 
     # Dynamic values
-    llm: ChatGoogleGenerativeAI  # Main LLM
-    check_llm: ChatGoogleGenerativeAI  # LLM for validation/extraction
+    llm: BaseChatModel  # Use generic BaseChatModel
+    check_llm: BaseChatModel  # Use generic BaseChatModel
     vector_store: VectorStore
     summarize_chain: Any  # Type hint could be improved
 
@@ -356,8 +358,11 @@ async def close(self):
             (Exception, ResourceExhausted)
         ),  # Retry on general exceptions and specifically on ResourceExhausted (429)
     )
-    async def _get_llm(self, model_name: str) -> ChatGoogleGenerativeAI:
-        """Gets or creates a ChatGoogleGenerativeAI instance with caching, rate limiting, and retry logic."""
+    async def _get_llm(self, model_name: str) -> BaseChatModel:
+        """Gets or creates a LLM instance with caching, rate limiting, and retry logic.
+
+        Handles both Google Gemini and OpenRouter models.
+        """
         async with self._lock:  # Protect access to the class-level cache
             if model_name in self._llm_cache:
                 self.logger.debug(f"LLM Cache HIT for model: {model_name}")
@@ -367,45 +372,108 @@ async def _get_llm(self, model_name: str) -> ChatGoogleGenerativeAI:
                 f"LLM Cache MISS for model: {model_name}. Creating new instance."
             )
 
-            is_pro_model = "pro" in model_name
-            rpm = GEMINI_PRO_RPM if is_pro_model else GEMINI_FLASH_RPM
-            tpm = GEMINI_PRO_TPM if is_pro_model else GEMINI_FLASH_TPM
-            rpd = GEMINI_PRO_RPD if is_pro_model else GEMINI_FLASH_RPD
+            llm_instance: BaseChatModel
+            api_key_to_use: Optional[str]
 
-            rate_limiter = StreamingRateLimiter(rpm, tpm, rpd)
+            if model_name.startswith("openrouter/"):
+                # --- Handle OpenRouter Models ---
+                api_key_to_use = await self.api_key_manager.get_openrouter_api_key()
+                if not api_key_to_use:
+                    self.logger.error("OpenRouter API key not set.")
+                    raise ValueError(
+                        "OpenRouter API key not set. Please set it in the settings."
+                    )
 
-            try:
-                llm = ChatGoogleGenerativeAI(
-                    model=model_name,
-                    google_api_key=self.api_key,
-                    temperature=float(
-                        self.model_settings.get("temperature", 0.7)
-                    ),  # Ensure float
-                    max_output_tokens=self.MAX_OUTPUT_TOKENS,
-                    # max_input_tokens not directly supported, handled via context truncation
-                    convert_system_message_to_human=True,  # Often needed for Gemini
-                    streaming=True,  # Keep streaming enabled
-                    callbacks=[rate_limiter],
-                    # Caching enabled globally via set_llm_cache
-                )
-                self._llm_cache[model_name] = llm
-                self.logger.info(
-                    f"LLM instance created and cached for model: {model_name}"
-                )
-                return llm
-            except Exception as e:
-                self.logger.error(
-                    f"Failed to create LLM instance for {model_name}: {e}",
-                    exc_info=True,
-                )
-                raise  # Re-raise the exception after logging
+                openrouter_model_name = model_name.split("/", 1)[
+                    1
+                ]  # e.g., openai/gpt-4o
+                openrouter_base_url = "https://openrouter.ai/api/v1"
+                # Optional headers for OpenRouter ranking
+                # TODO: Make these configurable if desired
+                site_url = os.getenv(
+                    "YOUR_SITE_URL", "https://www.op.scrllwise.com/"
+                )  # Placeholder
+                site_name = os.getenv("YOUR_SITE_NAME", "Scrollwise AI")  # Placeholder
+                extra_headers = {
+                    "HTTP-Referer": site_url,
+                    "X-Title": site_name,
+                }
+
+                try:
+                    llm_instance = ChatOpenAI(
+                        model=openrouter_model_name,
+                        openai_api_key=api_key_to_use,
+                        openai_api_base=openrouter_base_url,
+                        temperature=float(self.model_settings.get("temperature", 0.7)),
+                        max_tokens=self.MAX_OUTPUT_TOKENS,  # Corresponds to max_tokens for OpenAI
+                        streaming=True,  # Keep streaming enabled
+                        # Pass extra headers via model_kwargs
+                        model_kwargs={"extra_headers": extra_headers},
+                        # Caching enabled globally via set_llm_cache
+                        # No specific rate limiter applied here yet, OpenRouter handles its own limits.
+                        # callbacks=[],
+                    )
+                    self.logger.info(
+                        f"ChatOpenAI instance created for OpenRouter model: {openrouter_model_name}"
+                    )
+
+                except Exception as e:
+                    self.logger.error(
+                        f"Failed to create ChatOpenAI instance for {model_name}: {e}",
+                        exc_info=True,
+                    )
+                    raise
+
+            else:
+                # --- Handle Google Gemini Models (Existing Logic) ---
+                api_key_to_use = await self.api_key_manager.get_api_key()
+                if not api_key_to_use:
+                    self.logger.error("Google API key not set.")
+                    raise ValueError(
+                        "Google API key not set. Please set it in the settings."
+                    )
+
+                is_pro_model = "pro" in model_name
+                rpm = GEMINI_PRO_RPM if is_pro_model else GEMINI_FLASH_RPM
+                tpm = GEMINI_PRO_TPM if is_pro_model else GEMINI_FLASH_TPM
+                rpd = GEMINI_PRO_RPD if is_pro_model else GEMINI_FLASH_RPD
+
+                rate_limiter = StreamingRateLimiter(rpm, tpm, rpd)
+
+                try:
+                    llm_instance = ChatGoogleGenerativeAI(
+                        model=model_name,
+                        google_api_key=api_key_to_use,
+                        temperature=float(
+                            self.model_settings.get("temperature", 0.7)
+                        ),  # Ensure float
+                        max_output_tokens=self.MAX_OUTPUT_TOKENS,
+                        convert_system_message_to_human=True,  # Keep for Gemini
+                        streaming=True,  # Keep streaming enabled
+                        callbacks=[rate_limiter],
+                        # Caching enabled globally via set_llm_cache
+                    )
+                    self.logger.info(
+                        f"ChatGoogleGenerativeAI instance created for model: {model_name}"
+                    )
+                except Exception as e:
+                    self.logger.error(
+                        f"Failed to create LLM instance for {model_name}: {e}",
+                        exc_info=True,
+                    )
+                    raise
+
+            # Add the created instance to the cache
+            self._llm_cache[model_name] = llm_instance
+            return llm_instance
 
     async def _get_api_key(self) -> str:
+        # This method specifically gets the GOOGLE key now
         api_key = await self.api_key_manager.get_api_key()  # Removed user_id
         if not api_key:
-            self.logger.error(f"API key not found for local instance.")
+            self.logger.error(f"Google API key not found for local instance.")
             raise ValueError(
-                "API key not set. Please set your API key in the settings."
+                "Google API key not set. Please set your API key in the settings."
             )
         return api_key
 
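Taken together, these hunks make _get_llm route purely on the model-name prefix: anything namespaced "openrouter/..." is served through OpenRouter's OpenAI-compatible endpoint, everything else stays on Gemini. A minimal standalone sketch of that routing rule (the build_llm helper and its arguments are hypothetical; the prefix convention, keyword arguments, and base URL come from the diff above):

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI

def build_llm(model_name: str, google_key: str, openrouter_key: str):
    # "openrouter/openai/gpt-4o" -> OpenRouter (OpenAI-compatible API);
    # "gemini-1.5-pro" and friends -> Google Gemini.
    if model_name.startswith("openrouter/"):
        return ChatOpenAI(
            model=model_name.split("/", 1)[1],  # strip the "openrouter/" namespace
            openai_api_key=openrouter_key,
            openai_api_base="https://openrouter.ai/api/v1",
            streaming=True,
        )
    return ChatGoogleGenerativeAI(
        model=model_name,
        google_api_key=google_key,
        streaming=True,
    )

Caching and rate limiting stay inside the class itself; note that in this commit only the Gemini branch attaches the StreamingRateLimiter callback, while OpenRouter is left to enforce its own limits.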

backend/api_key_manager.py

Lines changed: 107 additions & 18 deletions
@@ -54,65 +54,154 @@ def decrypt_data(self, encrypted_data: str) -> str:
 
 
 class ApiKeyManager:
+    # Existing file for Google API Key
     API_KEY_FILE = Path("./api_key.dat")
+    # New file for OpenRouter API Key
+    OPENROUTER_API_KEY_FILE = Path("./openrouter_api_key.dat")
 
     def __init__(self, security_manager: SecurityManager):
         self.security_manager = security_manager
 
+    # --- Google API Key Methods ---
+
     async def save_api_key(self, api_key: str) -> None:
+        """Saves the Google API key."""
         if api_key is None:
             logger.error("API key cannot be None")
             raise HTTPException(status_code=400, detail="API key cannot be None")
         try:
             encrypted_key = self.security_manager.encrypt_data(api_key)
             self.API_KEY_FILE.write_text(encrypted_key)
-            logger.info(f"API key saved locally to {self.API_KEY_FILE}")
+            logger.info(f"Google API key saved locally to {self.API_KEY_FILE}")
         except IOError as e:
-            logger.error(f"Error writing API key file: {e}")
-            raise HTTPException(status_code=500, detail="Error saving API key file")
+            logger.error(f"Error writing Google API key file: {e}")
+            raise HTTPException(
+                status_code=500, detail="Error saving Google API key file"
+            )
         except Exception as e:
-            logger.error(f"Error saving API key: {e}")
-            raise HTTPException(status_code=500, detail="Error saving API key")
+            logger.error(f"Error saving Google API key: {e}")
+            raise HTTPException(status_code=500, detail="Error saving Google API key")
 
     async def get_api_key(self) -> Optional[str]:
+        """Gets the Google API key."""
         try:
             if not self.API_KEY_FILE.exists():
-                logger.info("Local API key file not found.")
+                logger.info("Local Google API key file not found.")
                 return None
 
             encrypted_key = self.API_KEY_FILE.read_text()
             if not encrypted_key:
-                logger.warning("Local API key file is empty.")
+                logger.warning("Local Google API key file is empty.")
                 return None
 
             try:
                 decrypted_key = self.security_manager.decrypt_data(encrypted_key)
                 return decrypted_key
             except ValueError as e:
-                logger.error(f"Decryption failed for local API key: {e}")
+                logger.error(f"Decryption failed for local Google API key: {e}")
                 return None
             except Exception as e:
-                logger.error(f"Unexpected error during decryption: {e}")
+                logger.error(f"Unexpected error during Google key decryption: {e}")
                 return None
 
         except IOError as e:
-            logger.error(f"Error reading API key file: {e}")
+            logger.error(f"Error reading Google API key file: {e}")
             return None
         except Exception as e:
-            logger.error(f"Error retrieving API key: {e}")
+            logger.error(f"Error retrieving Google API key: {e}")
             return None
 
     async def remove_api_key(self) -> None:
-        """Removes the locally stored API key file."""
+        """Removes the locally stored Google API key file."""
         try:
             if self.API_KEY_FILE.exists():
                 self.API_KEY_FILE.unlink()
-                logger.info(f"Local API key file removed: {self.API_KEY_FILE}")
+                logger.info(f"Local Google API key file removed: {self.API_KEY_FILE}")
+            else:
+                logger.info("Local Google API key file not found, nothing to remove.")
+        except IOError as e:
+            logger.error(f"Error removing Google API key file: {e}")
+            raise HTTPException(
+                status_code=500, detail="Error removing Google API key file"
+            )
+        except Exception as e:
+            logger.error(f"Error removing Google API key: {e}")
+            raise HTTPException(status_code=500, detail="Error removing Google API key")
+
+    # --- OpenRouter API Key Methods ---
+
+    async def save_openrouter_api_key(self, api_key: str) -> None:
+        """Saves the OpenRouter API key."""
+        if api_key is None:
+            logger.error("OpenRouter API key cannot be None")
+            raise HTTPException(
+                status_code=400, detail="OpenRouter API key cannot be None"
+            )
+        try:
+            encrypted_key = self.security_manager.encrypt_data(api_key)
+            self.OPENROUTER_API_KEY_FILE.write_text(encrypted_key)
+            logger.info(
+                f"OpenRouter API key saved locally to {self.OPENROUTER_API_KEY_FILE}"
+            )
+        except IOError as e:
+            logger.error(f"Error writing OpenRouter API key file: {e}")
+            raise HTTPException(
+                status_code=500, detail="Error saving OpenRouter API key file"
+            )
+        except Exception as e:
+            logger.error(f"Error saving OpenRouter API key: {e}")
+            raise HTTPException(
+                status_code=500, detail="Error saving OpenRouter API key"
+            )
+
+    async def get_openrouter_api_key(self) -> Optional[str]:
+        """Gets the OpenRouter API key."""
+        try:
+            if not self.OPENROUTER_API_KEY_FILE.exists():
+                logger.info("Local OpenRouter API key file not found.")
+                return None
+
+            encrypted_key = self.OPENROUTER_API_KEY_FILE.read_text()
+            if not encrypted_key:
+                logger.warning("Local OpenRouter API key file is empty.")
+                return None
+
+            try:
+                decrypted_key = self.security_manager.decrypt_data(encrypted_key)
+                return decrypted_key
+            except ValueError as e:
+                logger.error(f"Decryption failed for local OpenRouter API key: {e}")
+                return None
+            except Exception as e:
+                logger.error(f"Unexpected error during OpenRouter key decryption: {e}")
+                return None
+
+        except IOError as e:
+            logger.error(f"Error reading OpenRouter API key file: {e}")
+            return None
+        except Exception as e:
+            logger.error(f"Error retrieving OpenRouter API key: {e}")
+            return None
+
+    async def remove_openrouter_api_key(self) -> None:
+        """Removes the locally stored OpenRouter API key file."""
+        try:
+            if self.OPENROUTER_API_KEY_FILE.exists():
+                self.OPENROUTER_API_KEY_FILE.unlink()
+                logger.info(
+                    f"Local OpenRouter API key file removed: {self.OPENROUTER_API_KEY_FILE}"
+                )
             else:
-                logger.info("Local API key file not found, nothing to remove.")
+                logger.info(
+                    "Local OpenRouter API key file not found, nothing to remove."
+                )
         except IOError as e:
-            logger.error(f"Error removing API key file: {e}")
-            raise HTTPException(status_code=500, detail="Error removing API key file")
+            logger.error(f"Error removing OpenRouter API key file: {e}")
+            raise HTTPException(
+                status_code=500, detail="Error removing OpenRouter API key file"
+            )
         except Exception as e:
-            logger.error(f"Error removing API key: {e}")
-            raise HTTPException(status_code=500, detail="Error removing API key")
+            logger.error(f"Error removing OpenRouter API key: {e}")
+            raise HTTPException(
+                status_code=500, detail="Error removing OpenRouter API key"
+            )
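
The OpenRouter key methods mirror the existing Google ones one-for-one (separate .dat file, same encrypt/decrypt path), so settings endpoints can treat both providers symmetrically. A rough FastAPI-style usage sketch; the route paths, import path, and no-argument SecurityManager() construction are assumptions, while the ApiKeyManager method names come from the diff:

from fastapi import APIRouter
from api_key_manager import ApiKeyManager, SecurityManager  # import path assumed

router = APIRouter()
manager = ApiKeyManager(SecurityManager())  # assumes SecurityManager needs no arguments

@router.post("/settings/openrouter-api-key")
async def set_openrouter_key(api_key: str):
    # Encrypts and persists the key to openrouter_api_key.dat
    await manager.save_openrouter_api_key(api_key)
    return {"status": "saved"}

@router.get("/settings/openrouter-api-key")
async def openrouter_key_status():
    # Only reports presence; the decrypted key itself is never returned.
    return {"configured": await manager.get_openrouter_api_key() is not None}

@router.delete("/settings/openrouter-api-key")
async def remove_openrouter_key():
    await manager.remove_openrouter_api_key()
    return {"status": "removed"}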

backend/requirements.txt

Lines changed: 4 additions & 1 deletion
@@ -19,4 +19,7 @@ uvicorn
 langchain-qdrant
 psutil
 qdrant-client
-python-multipart
+python-multipart
+langchain_openai
+httpx
+tiktoken
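
Of the three new dependencies, langchain_openai provides the ChatOpenAI client used above, tiktoken backs its token counting, and httpx presumably serves the model fetching mentioned in the commit title. A minimal sketch of pulling the model list from OpenRouter's public /models endpoint (the function name and response filtering are illustrative, not necessarily this commit's implementation):

import httpx

async def fetch_openrouter_models(api_key: str) -> list[str]:
    # Same base URL the ChatOpenAI client points at.
    async with httpx.AsyncClient(base_url="https://openrouter.ai/api/v1") as client:
        resp = await client.get(
            "/models", headers={"Authorization": f"Bearer {api_key}"}
        )
        resp.raise_for_status()
        models = resp.json().get("data", [])
    # Prefix each id so it round-trips through the "openrouter/" check in _get_llm.
    return [f"openrouter/{m['id']}" for m in models]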
