Skip to content

Commit c694da4

Browse files
committed
Support for OpenAI and Gemini models at the API layer
1 parent 3c0e6c5 commit c694da4

File tree

4 files changed

+437
-12
lines changed

4 files changed

+437
-12
lines changed

app/core/model_handlers.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from app.core.exceptions import APIError, InvalidModelError, ModelHandlerError, JSONParsingError
1313
from app.core.telemetry_integration import track_llm_operation
1414
from app.core.config import _get_caii_token
15+
import os
16+
from dotenv import load_dotenv
17+
load_dotenv()
18+
import google.generativeai as genai
1519

1620

1721

@@ -154,11 +158,21 @@ def _extract_json_from_text(self, text: str) -> List[Dict[str, Any]]:
154158

155159

156160
#@track_llm_operation("generate")
157-
def generate_response(self, prompt: str, retry_with_reduced_tokens: bool = True, request_id = None) -> List[Dict[str, str]]:
161+
def generate_response(
162+
self,
163+
prompt: str,
164+
retry_with_reduced_tokens: bool = True,
165+
request_id: Optional[str] = None,
166+
):
158167
if self.inference_type == "aws_bedrock":
159168
return self._handle_bedrock_request(prompt, retry_with_reduced_tokens)
160-
elif self.inference_type == "CAII":
169+
if self.inference_type == "CAII":
161170
return self._handle_caii_request(prompt)
171+
if self.inference_type == "openai":
172+
return self._handle_openai_request(prompt)
173+
if self.inference_type == "gemini":
174+
return self._handle_gemini_request(prompt)
175+
raise ModelHandlerError(f"Unsupported inference_type={self.inference_type}", 400)
162176

163177
def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: bool):
164178
"""Handle Bedrock requests with retry logic"""
@@ -278,6 +292,50 @@ def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: bool):
278292
raise ModelHandlerError(f"Failed after {self.MAX_RETRIES} retries: {str(last_exception)}", status_code=500)
279293

280294

295+
# ---------- OpenAI -------------------------------------------------------
296+
def _handle_openai_request(self, prompt: str):
297+
try:
298+
client = OpenAI(
299+
api_key=os.getenv("OPENAI_API_KEY"),
300+
base_url=os.getenv("OPENAI_API_BASE", None) or None,
301+
)
302+
completion = client.chat.completions.create(
303+
model=self.model_id,
304+
messages=[{"role": "user", "content": prompt}],
305+
max_tokens=self.model_params.max_tokens,
306+
temperature=self.model_params.temperature,
307+
top_p=self.model_params.top_p,
308+
stream=False,
309+
)
310+
text = completion.choices[0].message.content
311+
return self._extract_json_from_text(text) if not self.custom_p else text
312+
except Exception as e:
313+
raise ModelHandlerError(f"OpenAI request failed: {e}", 500)
314+
315+
# ---------- Gemini -------------------------------------------------------
316+
def _handle_gemini_request(self, prompt: str):
317+
if genai is None:
318+
raise ModelHandlerError(
319+
"google-generativeai library not installed — `pip install google-generativeai`",
320+
500,
321+
)
322+
try:
323+
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
324+
model = genai.GenerativeModel(self.model_id) # e.g. 'gemini-1.5-pro-latest'
325+
resp = model.generate_content(
326+
prompt,
327+
generation_config={
328+
"max_output_tokens": self.model_params.max_tokens,
329+
"temperature": self.model_params.temperature,
330+
"top_p": self.model_params.top_p,
331+
},
332+
)
333+
text = resp.text
334+
return self._extract_json_from_text(text) if not self.custom_p else text
335+
except Exception as e:
336+
raise ModelHandlerError(f"Gemini request failed: {e}", 500)
337+
338+
281339
def _handle_caii_request(self, prompt: str):
282340
"""Original CAII implementation"""
283341
try:

app/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,9 @@ async def get_model_id():
731731

732732
models = {
733733
"aws_bedrock": bedrock_list,
734-
"CAII": []
734+
"CAII": [],
735+
"OpenAI" : [],
736+
"Google Gemini" : []
735737
}
736738

737739
return {"models": models}

pyproject.toml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,14 @@ dependencies = [
3636
"psutil==5.9.8",
3737
"pandas>=2.2.3",
3838
# ── new packages for data‑analysis layer ─────────────────
39-
"numpy>=1.24.0", # explicit, for dcor/scipy (already a transitive dep of pandas)
40-
"scipy>=1.12.0", # chi‑square, ANOVA, etc.
41-
"dcor>=0.6", # distance‑correlation metric
42-
"openpyxl>=3.1.2", # read .xlsx files
43-
"pyxlsb>=1.0.9", # read .xlsb files
39+
"numpy>=1.24.0", # explicit, for dcor/scipy (already a transitive dep of pandas)
40+
"scipy>=1.12.0", # chi‑square, ANOVA, etc.
41+
"dcor>=0.6", # distance‑correlation metric
42+
"openpyxl>=3.1.2", # read .xlsx files
43+
"pyxlsb>=1.0.9", # read .xlsb files
44+
"google>=3.0.0",
45+
"google-generativeai>=0.8.5",
46+
"google-genai>=1.2.0",
4447
]
4548

4649

0 commit comments

Comments
 (0)