Skip to content

Handle Anthropic messages as if they were chat/completions #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 8 additions & 139 deletions src/api/app.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,16 @@
import logging
import requests
import os
import uvicorn
from fastapi import FastAPI, Request, Response

from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from mangum import Mangum
import httpx
import json
import os
from contextlib import asynccontextmanager

from api.setting import API_ROUTE_PREFIX, DESCRIPTION, GCP_ENDPOINT, GCP_PROJECT_ID, GCP_REGION, SUMMARY, PROVIDER, TITLE, USE_MODEL_MAPPING, VERSION

from google.auth import default
from google.auth.transport.requests import Request as AuthRequest

from api.modelmapper import get_model, load_model_map

# GCP credentials and project details
credentials = None
project_id = None
location = None
from api.setting import API_ROUTE_PREFIX, DESCRIPTION, SUMMARY, PROVIDER, TITLE, USE_MODEL_MAPPING, VERSION
from api.modelmapper import load_model_map
from api.routers.vertex import handle_proxy

def is_aws():
env = os.getenv("AWS_EXECUTION_ENV")
Expand All @@ -43,124 +32,6 @@ def is_aws():
if USE_MODEL_MAPPING:
load_model_map()


def get_gcp_project_details():
    """Resolve GCP credentials, project id and region.

    Prefers the explicit GCP_PROJECT_ID / GCP_REGION settings; falls back
    to application-default credentials and the GCE metadata server for
    whichever of the two is unset.

    Returns:
        tuple: (credentials, project_id, location). ``credentials`` stays
        None and the other two keep their configured values when
        discovery fails (e.g. when not running on GCP).
    """
    from google.auth import default

    credentials = None
    project_id = GCP_PROJECT_ID
    location = GCP_REGION

    try:
        credentials, detected_project = default()
        if not project_id:
            project_id = detected_project

        if not location:
            # Metadata server returns "projects/<num>/zones/<zone>"; drop the
            # trailing zone letter (e.g. "-a") to obtain the region.
            zone = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/zone",
                headers={"Metadata-Flavor": "Google"},
                timeout=1,
            ).text
            location = zone.split("/")[-1].rsplit("-", 1)[0]

    except Exception as exc:
        # Best-effort discovery: log the reason (the original message was an
        # f-string with no placeholders) and fall back to local settings.
        logging.warning(
            "Failed to get project and location from metadata server; "
            "using local settings. (%s)", exc
        )

    return credentials, project_id, location

# On non-AWS deployments, resolve GCP credentials/project/region once at
# import time so the proxy target can be derived below.
if not is_aws():
    credentials, project_id, location = get_gcp_project_details()

# Utility: get service account access token
def get_access_token():
    """Return a fresh OAuth2 bearer token for the ambient service account."""
    creds, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
    creds.refresh(AuthRequest())
    return creds.token

def get_gcp_target():
    """Build the Vertex AI base URL from the discovered project and region.

    Returns None when either piece of information is unavailable.
    """
    if not (project_id and location):
        return None
    return (
        f"https://{location}-aiplatform.googleapis.com/v1beta1"
        f"/projects/{project_id}/locations/{location}/{GCP_ENDPOINT}/"
    )

def get_proxy_target():
    """Pick the upstream base URL.

    An explicit PROXY_TARGET environment variable wins; otherwise fall back
    to the derived GCP endpoint; otherwise None.
    """
    explicit = os.getenv("PROXY_TARGET")
    if explicit:
        return explicit
    return get_gcp_target() or None

def get_header(request, path):
    """Compute the upstream URL and the forwardable headers for *request*.

    Strips hop-by-hop headers plus the caller's Authorization, then injects
    a fresh service-account bearer token. Returns (target_url, headers).
    """
    stripped = f"/{path.lstrip('/')}".removeprefix(API_ROUTE_PREFIX)
    target_url = f"{proxy_target.rstrip('/')}/{stripped.lstrip('/')}".rstrip("/")

    # Drop hop-by-hop headers and the incoming Authorization header.
    excluded = {"host", "content-length", "accept-encoding", "connection", "authorization"}
    headers = {}
    for name, value in request.headers.items():
        if name.lower() not in excluded:
            headers[name] = value

    # Authenticate upstream as ourselves, not as the caller.
    headers["Authorization"] = f"Bearer {get_access_token()}"
    return target_url, headers

async def handle_proxy(request: Request, path: str):
    """Forward *request* to the configured upstream at *path*.

    When model mapping is enabled and the body is a JSON object, the
    "model" field is rewritten via the model map before forwarding. The
    upstream response is returned with hop-by-hop headers removed.

    Returns:
        Response: the upstream response, or a 502 when the upstream
        request fails at the transport level.
    """
    # Build safe target URL and forwardable headers.
    target_url, headers = get_header(request, path)

    try:
        content = await request.body()

        if USE_MODEL_MAPPING and content:
            # The original code parsed every body unconditionally, so any
            # non-JSON (or empty) request crashed with a 500. Parse
            # defensively and forward unrecognized payloads untouched.
            try:
                data = json.loads(content)
            except ValueError:
                data = None

            if isinstance(data, dict):
                request_model = data.get("model")
                if request_model is not None:
                    model = get_model("gcp", request_model)

                    # Vertex publisher models are addressed as "google/<name>".
                    if model is not None and model != request_model and "publishers/google/" in model:
                        model = f"google/{model.split('/')[-1]}"

                    data["model"] = model
                    content = json.dumps(data)

        async with httpx.AsyncClient() as client:
            response = await client.request(
                method=request.method,
                url=target_url,
                headers=headers,
                content=content,
                params=request.query_params,
                timeout=30.0,
            )
    except httpx.RequestError as e:
        logging.error(f"Proxy request failed: {e}")
        return Response(status_code=502, content=f"Upstream request failed: {e}")

    # remove hop-by-hop headers
    response_headers = {
        k: v for k, v in response.headers.items()
        if k.lower() not in {"content-encoding", "transfer-encoding", "connection"}
    }

    return Response(
        content=response.content,
        status_code=response.status_code,
        headers=response_headers,
        media_type=response.headers.get("content-type", "application/octet-stream"),
    )

config = {
"title": TITLE,
"description": DESCRIPTION,
Expand All @@ -173,8 +44,6 @@ async def handle_proxy(request: Request, path: str):
format="%(asctime)s [%(levelname)s] %(message)s",
)

proxy_target = get_proxy_target()

app = FastAPI(**config)
app.add_middleware(
CORSMiddleware,
Expand All @@ -184,8 +53,8 @@ async def handle_proxy(request: Request, path: str):
allow_headers=["*"],
)

if provider != "aws" and proxy_target:
logging.info(f"Proxy target set to: {proxy_target}")
if provider != "aws":
logging.info(f"Proxy target set to: GCP")
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def proxy(request: Request, path: str):
return await handle_proxy(request, path)
Expand Down
2 changes: 1 addition & 1 deletion src/api/modelmapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def load_model_map():

def get_model(provider, model):
    """Map a client-supplied model name to the provider's canonical name.

    The ":latest" suffix (appended by some OpenAI-compatible clients) is
    stripped before lookup. Unknown models — and a None model — are
    returned unchanged so callers decide how to react.
    """
    if model is None:
        # Callers pass data.get("model"), which may be absent; avoid
        # AttributeError on None.lower().
        return None

    provider = provider.lower()
    model = model.lower().removesuffix(":latest")

    available_models = _model_map.get(provider, {})
    return available_models.get(model, model)
Expand Down
1 change: 1 addition & 0 deletions src/api/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def validate(self, chat_request: ChatRequest):
f"Unsupported model '{chat_request.model}'. "
f"list of known models: {bedrock_model_list.keys()}"
)
logger.error(error)

if error:
raise HTTPException(
Expand Down
Loading