 from app.core.exceptions import APIError, InvalidModelError, ModelHandlerError, JSONParsingError
 from app.core.telemetry_integration import track_llm_operation
 from app.core.config import _get_caii_token
+import os
+from dotenv import load_dotenv
+load_dotenv()
+try:
+    import google.generativeai as genai  # optional dependency; _handle_gemini_request checks for None
+except ImportError:
+    genai = None



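The new handlers read their credentials from the process environment, with load_dotenv() pulling in a local .env file if one exists. A minimal sanity-check sketch, using only the variable names that appear in this diff (values and file location are up to the deployment):

import os
from dotenv import load_dotenv

load_dotenv()  # picks up a .env file from the working directory, if present

# Variable names taken from the handlers added below; OPENAI_API_BASE is optional
# and only matters when pointing at an OpenAI-compatible gateway.
for var in ("OPENAI_API_KEY", "GEMINI_API_KEY"):
    if not os.getenv(var):
        print(f"warning: {var} is not set")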
@@ -154,11 +158,21 @@ def _extract_json_from_text(self, text: str) -> List[Dict[str, Any]]:


     #@track_llm_operation("generate")
-    def generate_response(self, prompt: str, retry_with_reduced_tokens: bool = True, request_id = None) -> List[Dict[str, str]]:
+    def generate_response(
+        self,
+        prompt: str,
+        retry_with_reduced_tokens: bool = True,
+        request_id: Optional[str] = None,
+    ):
         if self.inference_type == "aws_bedrock":
             return self._handle_bedrock_request(prompt, retry_with_reduced_tokens)
-        elif self.inference_type == "CAII":
+        if self.inference_type == "CAII":
             return self._handle_caii_request(prompt)
+        if self.inference_type == "openai":
+            return self._handle_openai_request(prompt)
+        if self.inference_type == "gemini":
+            return self._handle_gemini_request(prompt)
+        raise ModelHandlerError(f"Unsupported inference_type={self.inference_type}", status_code=400)

     def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: bool):
         """Handle Bedrock requests with retry logic"""
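For orientation, the sketch below shows how a caller would reach the new branches of the dispatch above. The constructor arguments are assumptions inferred from the attributes the handler reads (model_id, inference_type, model_params, custom_p); the class name and signature are not part of this diff, so treat it purely as a sketch:

# Hypothetical construction; only generate_response() and the attribute names are grounded in the diff.
openai_handler = ModelHandler(model_id="gpt-4o-mini", inference_type="openai", model_params=params)
rows = openai_handler.generate_response(prompt)   # parsed JSON list unless custom_p is set

gemini_handler = ModelHandler(model_id="gemini-1.5-pro-latest", inference_type="gemini", model_params=params)
rows = gemini_handler.generate_response(prompt)

# Any other inference_type now raises ModelHandlerError with status_code=400 instead of silently returning None.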
@@ -278,6 +292,50 @@ def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: bool):
         raise ModelHandlerError(f"Failed after {self.MAX_RETRIES} retries: {str(last_exception)}", status_code=500)


+    # ---------- OpenAI -------------------------------------------------------
+    def _handle_openai_request(self, prompt: str):
+        try:
+            client = OpenAI(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                base_url=os.getenv("OPENAI_API_BASE") or None,
+            )
+            completion = client.chat.completions.create(
+                model=self.model_id,
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=self.model_params.max_tokens,
+                temperature=self.model_params.temperature,
+                top_p=self.model_params.top_p,
+                stream=False,
+            )
+            text = completion.choices[0].message.content
+            return self._extract_json_from_text(text) if not self.custom_p else text
+        except Exception as e:
+            raise ModelHandlerError(f"OpenAI request failed: {e}", status_code=500)
+
+    # ---------- Gemini -------------------------------------------------------
+    def _handle_gemini_request(self, prompt: str):
+        if genai is None:
+            raise ModelHandlerError(
+                "google-generativeai library not installed: pip install google-generativeai",
+                status_code=500,
+            )
+        try:
+            genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
+            model = genai.GenerativeModel(self.model_id)  # e.g. 'gemini-1.5-pro-latest'
+            resp = model.generate_content(
+                prompt,
+                generation_config={
+                    "max_output_tokens": self.model_params.max_tokens,
+                    "temperature": self.model_params.temperature,
+                    "top_p": self.model_params.top_p,
+                },
+            )
+            text = resp.text
+            return self._extract_json_from_text(text) if not self.custom_p else text
+        except Exception as e:
+            raise ModelHandlerError(f"Gemini request failed: {e}", status_code=500)
+
+
     def _handle_caii_request(self, prompt: str):
         """Original CAII implementation"""
         try:
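For quick verification outside the application, the snippet below mirrors the raw SDK calls the two new handlers make. Model ids and the prompt are placeholders, and it assumes OPENAI_API_KEY and GEMINI_API_KEY are set; it is a smoke test of the dependencies, not part of the PR:

import os
from openai import OpenAI
import google.generativeai as genai

# Mirrors _handle_openai_request: plain chat completion, no streaming.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_API_BASE") or None)
completion = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model id
    messages=[{"role": "user", "content": "Reply with the single word OK"}],
    max_tokens=5,
)
print(completion.choices[0].message.content)

# Mirrors _handle_gemini_request: generate_content with a generation_config dict.
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
model = genai.GenerativeModel("gemini-1.5-pro-latest")  # placeholder model id
resp = model.generate_content("Reply with the single word OK",
                              generation_config={"max_output_tokens": 5})
print(resp.text)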