Skip to content

Commit 7066bbf

Browse files
committed
in progress checkpoint: handles vertex ai message conversion
1 parent d8ff6f5 commit 7066bbf

File tree

3 files changed

+612
-126
lines changed

3 files changed

+612
-126
lines changed

src/api/app.py

Lines changed: 6 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,16 @@
11
import logging
2-
import requests
2+
import os
33
import uvicorn
4-
from fastapi import FastAPI, Request, Response
4+
5+
from fastapi import FastAPI, Request
56
from fastapi.exceptions import RequestValidationError
67
from fastapi.middleware.cors import CORSMiddleware
78
from fastapi.responses import PlainTextResponse
89
from mangum import Mangum
9-
import httpx
10-
import json
11-
import os
12-
from contextlib import asynccontextmanager
13-
14-
from api.setting import API_ROUTE_PREFIX, DESCRIPTION, GCP_PROJECT_ID, GCP_REGION, SUMMARY, PROVIDER, TITLE, USE_MODEL_MAPPING, VERSION
1510

16-
from google.auth import default
17-
from google.auth.transport.requests import Request as AuthRequest
18-
19-
from api.modelmapper import get_model, load_model_map
20-
21-
# GCP credentials and project details
22-
credentials = None
23-
project_id = None
24-
location = None
11+
from api.setting import API_ROUTE_PREFIX, DESCRIPTION, SUMMARY, PROVIDER, TITLE, USE_MODEL_MAPPING, VERSION
12+
from api.modelmapper import load_model_map
13+
from api.routers.vertex import handle_proxy
2514

2615
def is_aws():
2716
env = os.getenv("AWS_EXECUTION_ENV")
@@ -43,115 +32,6 @@ def is_aws():
4332
if USE_MODEL_MAPPING:
4433
load_model_map()
4534

46-
47-
def get_gcp_project_details():
48-
from google.auth import default
49-
50-
# Try metadata server for region
51-
credentials = None
52-
project_id = GCP_PROJECT_ID
53-
location = GCP_REGION
54-
55-
try:
56-
credentials, project = default()
57-
if not project_id:
58-
project_id = project
59-
60-
if not location:
61-
zone = requests.get(
62-
"http://metadata.google.internal/computeMetadata/v1/instance/zone",
63-
headers={"Metadata-Flavor": "Google"},
64-
timeout=1
65-
).text
66-
location = zone.split("/")[-1].rsplit("-", 1)[0]
67-
68-
except Exception:
69-
logging.warning(f"Error: Failed to get project and location from metadata server. Using local settings.")
70-
71-
return credentials, project_id, location
72-
73-
if not is_aws():
74-
credentials, project_id, location = get_gcp_project_details()
75-
76-
# Utility: get service account access token
77-
def get_access_token():
78-
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
79-
auth_request = AuthRequest()
80-
credentials.refresh(auth_request)
81-
return credentials.token
82-
83-
def get_gcp_target(path):
84-
"""
85-
Check if the environment variable is set to use GCP.
86-
"""
87-
if os.getenv("PROXY_TARGET"):
88-
return os.getenv("PROXY_TARGET")
89-
else:
90-
return f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/{path.lstrip('/')}".rstrip("/")
91-
92-
def get_header(request, path):
93-
if "chat/completions" in path:
94-
path = path.replace("chat/completions", "endpoints/openapi/chat/completions")
95-
96-
path_no_prefix = f"/{path.lstrip('/')}".removeprefix(API_ROUTE_PREFIX)
97-
target_url = get_gcp_target(path_no_prefix)
98-
99-
# remove hop-by-hop headers
100-
headers = {
101-
k: v for k, v in request.headers.items()
102-
if k.lower() not in {"host", "content-length", "accept-encoding", "connection", "authorization"}
103-
}
104-
105-
# Fetch service account token
106-
access_token = get_access_token()
107-
headers["Authorization"] = f"Bearer {access_token}"
108-
return target_url,headers
109-
110-
async def handle_proxy(request: Request, path: str):
111-
# Build safe target URL
112-
target_url, headers = get_header(request, path)
113-
114-
try:
115-
content = await request.body()
116-
117-
if USE_MODEL_MAPPING:
118-
data = json.loads(content)
119-
if "model" in data:
120-
request_model = data.get("model", None)
121-
model = get_model("gcp", request_model)
122-
123-
if model != None and model != request_model and "publishers/google/" in model:
124-
model = f"google/{model.split('/')[-1]}"
125-
126-
data["model"]= model
127-
content = json.dumps(data)
128-
129-
async with httpx.AsyncClient() as client:
130-
response = await client.request(
131-
method=request.method,
132-
url=target_url,
133-
headers=headers,
134-
content=content,
135-
params=request.query_params,
136-
timeout=30.0,
137-
)
138-
except httpx.RequestError as e:
139-
logging.error(f"Proxy request failed: {e}")
140-
return Response(status_code=502, content=f"Upstream request failed: {e}")
141-
142-
# remove hop-by-hop headers
143-
response_headers = {
144-
k: v for k, v in response.headers.items()
145-
if k.lower() not in {"content-encoding", "transfer-encoding", "connection"}
146-
}
147-
148-
return Response(
149-
content=response.content,
150-
status_code=response.status_code,
151-
headers=response_headers,
152-
media_type=response.headers.get("content-type", "application/octet-stream"),
153-
)
154-
15535
config = {
15636
"title": TITLE,
15737
"description": DESCRIPTION,

0 commit comments

Comments
 (0)