@@ -3,6 +3,8 @@
 from typing import Dict, Any, List, Tuple, Optional, Union, TypedDict
 from dotenv import load_dotenv
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_openai import ChatOpenAI
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain.schema import StrOutputParser
 from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
@@ -207,8 +209,8 @@ class ChapterGenerationState(TypedDict):
     instructions: Dict[str, Any]

     # Dynamic values
-    llm: ChatGoogleGenerativeAI  # Main LLM
-    check_llm: ChatGoogleGenerativeAI  # LLM for validation/extraction
+    llm: BaseChatModel  # Use generic BaseChatModel
+    check_llm: BaseChatModel  # Use generic BaseChatModel
     vector_store: VectorStore
     summarize_chain: Any  # Type hint could be improved

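Why widening `llm` and `check_llm` to `BaseChatModel` is safe: both provider classes inherit from LangChain's common chat-model base, so state typed against the base accepts either backend. A minimal sketch to illustrate (the model name and dummy key are placeholders, not values from this PR):

    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_google_genai import ChatGoogleGenerativeAI
    from langchain_openai import ChatOpenAI

    # Both concrete chat classes share the same LangChain base class.
    assert issubclass(ChatGoogleGenerativeAI, BaseChatModel)
    assert issubclass(ChatOpenAI, BaseChatModel)

    # A field annotated as BaseChatModel can therefore hold either implementation.
    llm: BaseChatModel = ChatOpenAI(model="openai/gpt-4o", openai_api_key="sk-placeholder")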
@@ -356,8 +358,11 @@ async def close(self):
             (Exception, ResourceExhausted)
         ),  # Retry on general exceptions and specifically on ResourceExhausted (429)
     )
-    async def _get_llm(self, model_name: str) -> ChatGoogleGenerativeAI:
-        """Gets or creates a ChatGoogleGenerativeAI instance with caching, rate limiting, and retry logic."""
+    async def _get_llm(self, model_name: str) -> BaseChatModel:
+        """Gets or creates an LLM instance with caching, rate limiting, and retry logic.
+
+        Handles both Google Gemini and OpenRouter models.
+        """
         async with self._lock:  # Protect access to the class-level cache
             if model_name in self._llm_cache:
                 self.logger.debug(f"LLM Cache HIT for model: {model_name}")
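The decorator tail above reads as the tenacity retry pattern: retry on any `Exception`, and specifically on `ResourceExhausted` (Google's 429). A self-contained sketch of that pattern, assuming tenacity is the library in use; the stop and wait settings here are hypothetical, since the real ones sit above this hunk:

    from google.api_core.exceptions import ResourceExhausted
    from tenacity import (
        retry,
        retry_if_exception_type,
        stop_after_attempt,
        wait_exponential,
    )

    @retry(
        stop=stop_after_attempt(3),                   # hypothetical attempt budget
        wait=wait_exponential(multiplier=1, max=30),  # hypothetical backoff window
        retry=retry_if_exception_type((Exception, ResourceExhausted)),
    )
    async def create_llm_with_retry() -> None:
        # Stand-in for the _get_llm body; tenacity's @retry also wraps coroutines.
        ...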
@@ -367,45 +372,108 @@ async def _get_llm(self, model_name: str) -> ChatGoogleGenerativeAI:
                 f"LLM Cache MISS for model: {model_name}. Creating new instance."
             )

-            is_pro_model = "pro" in model_name
-            rpm = GEMINI_PRO_RPM if is_pro_model else GEMINI_FLASH_RPM
-            tpm = GEMINI_PRO_TPM if is_pro_model else GEMINI_FLASH_TPM
-            rpd = GEMINI_PRO_RPD if is_pro_model else GEMINI_FLASH_RPD
+            llm_instance: BaseChatModel
+            api_key_to_use: Optional[str]

-            rate_limiter = StreamingRateLimiter(rpm, tpm, rpd)
+            if model_name.startswith("openrouter/"):
+                # --- Handle OpenRouter Models ---
+                api_key_to_use = await self.api_key_manager.get_openrouter_api_key()
+                if not api_key_to_use:
+                    self.logger.error("OpenRouter API key not set.")
+                    raise ValueError(
+                        "OpenRouter API key not set. Please set it in the settings."
+                    )

-            try:
-                llm = ChatGoogleGenerativeAI(
-                    model=model_name,
-                    google_api_key=self.api_key,
-                    temperature=float(
-                        self.model_settings.get("temperature", 0.7)
-                    ),  # Ensure float
-                    max_output_tokens=self.MAX_OUTPUT_TOKENS,
-                    # max_input_tokens not directly supported, handled via context truncation
-                    convert_system_message_to_human=True,  # Often needed for Gemini
-                    streaming=True,  # Keep streaming enabled
-                    callbacks=[rate_limiter],
-                    # Caching enabled globally via set_llm_cache
-                )
-                self._llm_cache[model_name] = llm
-                self.logger.info(
-                    f"LLM instance created and cached for model: {model_name}"
-                )
-                return llm
-            except Exception as e:
-                self.logger.error(
-                    f"Failed to create LLM instance for {model_name}: {e}",
-                    exc_info=True,
-                )
-                raise  # Re-raise the exception after logging
+                openrouter_model_name = model_name.split("/", 1)[
+                    1
+                ]  # e.g., openai/gpt-4o
+                openrouter_base_url = "https://openrouter.ai/api/v1"
+                # Optional headers for OpenRouter ranking
+                # TODO: Make these configurable if desired
+                site_url = os.getenv(
+                    "YOUR_SITE_URL", "https://www.op.scrllwise.com/"
+                )  # Placeholder
+                site_name = os.getenv("YOUR_SITE_NAME", "Scrollwise AI")  # Placeholder
+                extra_headers = {
+                    "HTTP-Referer": site_url,
+                    "X-Title": site_name,
+                }
+
+                try:
+                    llm_instance = ChatOpenAI(
+                        model=openrouter_model_name,
+                        openai_api_key=api_key_to_use,
+                        openai_api_base=openrouter_base_url,
+                        temperature=float(self.model_settings.get("temperature", 0.7)),
+                        max_tokens=self.MAX_OUTPUT_TOKENS,  # Corresponds to max_tokens for OpenAI
+                        streaming=True,  # Keep streaming enabled
+                        # Pass extra headers via model_kwargs
+                        model_kwargs={"extra_headers": extra_headers},
+                        # Caching enabled globally via set_llm_cache
+                        # No specific rate limiter applied here yet; OpenRouter handles its own limits.
+                        # callbacks=[],
+                    )
+                    self.logger.info(
+                        f"ChatOpenAI instance created for OpenRouter model: {openrouter_model_name}"
+                    )
+
+                except Exception as e:
+                    self.logger.error(
+                        f"Failed to create ChatOpenAI instance for {model_name}: {e}",
+                        exc_info=True,
+                    )
+                    raise
+
+            else:
+                # --- Handle Google Gemini Models (Existing Logic) ---
+                api_key_to_use = await self.api_key_manager.get_api_key()
+                if not api_key_to_use:
+                    self.logger.error("Google API key not set.")
+                    raise ValueError(
+                        "Google API key not set. Please set it in the settings."
+                    )
+
+                is_pro_model = "pro" in model_name
+                rpm = GEMINI_PRO_RPM if is_pro_model else GEMINI_FLASH_RPM
+                tpm = GEMINI_PRO_TPM if is_pro_model else GEMINI_FLASH_TPM
+                rpd = GEMINI_PRO_RPD if is_pro_model else GEMINI_FLASH_RPD
+
+                rate_limiter = StreamingRateLimiter(rpm, tpm, rpd)
+
+                try:
+                    llm_instance = ChatGoogleGenerativeAI(
+                        model=model_name,
+                        google_api_key=api_key_to_use,
+                        temperature=float(
+                            self.model_settings.get("temperature", 0.7)
+                        ),  # Ensure float
+                        max_output_tokens=self.MAX_OUTPUT_TOKENS,
+                        convert_system_message_to_human=True,  # Keep for Gemini
+                        streaming=True,  # Keep streaming enabled
+                        callbacks=[rate_limiter],
+                        # Caching enabled globally via set_llm_cache
+                    )
+                    self.logger.info(
+                        f"ChatGoogleGenerativeAI instance created for model: {model_name}"
+                    )
+                except Exception as e:
+                    self.logger.error(
+                        f"Failed to create LLM instance for {model_name}: {e}",
+                        exc_info=True,
+                    )
+                    raise
+
+            # Add the created instance to the cache
+            self._llm_cache[model_name] = llm_instance
+            return llm_instance

     async def _get_api_key(self) -> str:
+        # This method specifically gets the GOOGLE key now
         api_key = await self.api_key_manager.get_api_key()  # Removed user_id
         if not api_key:
-            self.logger.error(f"API key not found for local instance.")
+            self.logger.error(f"Google API key not found for local instance.")
             raise ValueError(
-                "API key not set. Please set your API key in the settings."
+                "Google API key not set. Please set your API key in the settings."
             )
         return api_key

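Net effect of the diff: model names prefixed with `openrouter/` are routed to `ChatOpenAI` pointed at OpenRouter's OpenAI-compatible endpoint (with the prefix stripped off), while everything else stays on the Gemini path behind `StreamingRateLimiter`. A hedged usage sketch; `manager` stands in for whatever object owns `_get_llm`, and both model names are illustrative:

    async def demo(manager) -> None:
        # "pro" in the name selects the Gemini Pro rate limits inside _get_llm.
        gemini_llm = await manager._get_llm("gemini-1.5-pro")

        # The "openrouter/" prefix is split off, leaving "openai/gpt-4o" for OpenRouter.
        router_llm = await manager._get_llm("openrouter/openai/gpt-4o")

        # Both satisfy BaseChatModel, so downstream chains are unchanged.
        reply = await router_llm.ainvoke("Say hello.")
        print(reply.content)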