Handle Anthropic messages as if they were chat/completions #19

Merged · 14 commits · Jun 5, 2025
Changes from 6 commits
147 changes: 8 additions & 139 deletions src/api/app.py
@@ -1,27 +1,16 @@
import logging
import requests
import os
import uvicorn
from fastapi import FastAPI, Request, Response

from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from mangum import Mangum
import httpx
import json
import os
from contextlib import asynccontextmanager

from api.setting import API_ROUTE_PREFIX, DESCRIPTION, GCP_ENDPOINT, GCP_PROJECT_ID, GCP_REGION, SUMMARY, PROVIDER, TITLE, USE_MODEL_MAPPING, VERSION

from google.auth import default
from google.auth.transport.requests import Request as AuthRequest

from api.modelmapper import get_model, load_model_map

# GCP credentials and project details
credentials = None
project_id = None
location = None
from api.setting import API_ROUTE_PREFIX, DESCRIPTION, SUMMARY, PROVIDER, TITLE, USE_MODEL_MAPPING, VERSION
from api.modelmapper import load_model_map
from api.routers.vertex import handle_proxy

def is_aws():
    env = os.getenv("AWS_EXECUTION_ENV")
@@ -43,124 +32,6 @@ def is_aws():
if USE_MODEL_MAPPING:
    load_model_map()


def get_gcp_project_details():
    from google.auth import default

    # Try metadata server for region
    credentials = None
    project_id = GCP_PROJECT_ID
    location = GCP_REGION

    try:
        credentials, project = default()
        if not project_id:
            project_id = project

        if not location:
            zone = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/zone",
                headers={"Metadata-Flavor": "Google"},
                timeout=1
            ).text
            location = zone.split("/")[-1].rsplit("-", 1)[0]

    except Exception:
        logging.warning(f"Error: Failed to get project and location from metadata server. Using local settings.")

    return credentials, project_id, location

if not is_aws():
    credentials, project_id, location = get_gcp_project_details()

# Utility: get service account access token
def get_access_token():
    credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
    auth_request = AuthRequest()
    credentials.refresh(auth_request)
    return credentials.token

def get_gcp_target():
    """
    Check if the environment variable is set to use GCP.
    """
    if project_id and location:
        return f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/{GCP_ENDPOINT}/"

    return None

def get_proxy_target():
    """
    Check if the environment variable is set to use a proxy.
    """
    proxy_target = os.getenv("PROXY_TARGET")
    if proxy_target:
        return proxy_target
    gcp_target = get_gcp_target()
    if gcp_target:
        return gcp_target

    return None

def get_header(request, path):
    path_no_prefix = f"/{path.lstrip('/')}".removeprefix(API_ROUTE_PREFIX)
    target_url = f"{proxy_target.rstrip('/')}/{path_no_prefix.lstrip('/')}".rstrip("/")

    # remove hop-by-hop headers
    headers = {
        k: v for k, v in request.headers.items()
        if k.lower() not in {"host", "content-length", "accept-encoding", "connection", "authorization"}
    }

    # Fetch service account token
    access_token = get_access_token()
    headers["Authorization"] = f"Bearer {access_token}"
    return target_url, headers

async def handle_proxy(request: Request, path: str):
    # Build safe target URL
    target_url, headers = get_header(request, path)

    try:
        content = await request.body()
        data = json.loads(content)

        if USE_MODEL_MAPPING:
            request_model = data.get("model", None)
            model = get_model("gcp", request_model)

            if model != None and model != request_model and "publishers/google/" in model:
                model = f"google/{model.split('/')[-1]}"

            data["model"] = model
            content = json.dumps(data)

        async with httpx.AsyncClient() as client:
            response = await client.request(
                method=request.method,
                url=target_url,
                headers=headers,
                content=content,
                params=request.query_params,
                timeout=30.0,
            )
    except httpx.RequestError as e:
        logging.error(f"Proxy request failed: {e}")
        return Response(status_code=502, content=f"Upstream request failed: {e}")

    # remove hop-by-hop headers
    response_headers = {
        k: v for k, v in response.headers.items()
        if k.lower() not in {"content-encoding", "transfer-encoding", "connection"}
    }

    return Response(
        content=response.content,
        status_code=response.status_code,
        headers=response_headers,
        media_type=response.headers.get("content-type", "application/octet-stream"),
    )

config = {
    "title": TITLE,
    "description": DESCRIPTION,
@@ -173,8 +44,6 @@ async def handle_proxy(request: Request, path: str):
    format="%(asctime)s [%(levelname)s] %(message)s",
)

proxy_target = get_proxy_target()

app = FastAPI(**config)
app.add_middleware(
    CORSMiddleware,
@@ -184,8 +53,8 @@ async def handle_proxy(request: Request, path: str):
    allow_headers=["*"],
)

if provider != "aws" and proxy_target:
logging.info(f"Proxy target set to: {proxy_target}")
if provider != "aws":
logging.info(f"Proxy target set to: GCP")
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def proxy(request: Request, path: str):
return await handle_proxy(request, path)
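
With the proxy logic moved into api.routers.vertex, the intent is that clients keep sending plain OpenAI-style chat/completions requests even when the mapped target is an Anthropic model on Vertex. A minimal smoke test of that behavior might look like the sketch below; the host, route prefix, and model alias are assumptions for illustration, not part of this diff:

import requests

# Hypothetical local deployment; adjust URL, prefix, and model to your setup.
resp = requests.post(
    "http://localhost:8000/api/v1/chat/completions",
    json={
        "model": "claude-3-haiku",  # assumed alias; model mapping resolves the Vertex target
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=30,
)
print(resp.json()["choices"][0]["message"]["content"])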
199 changes: 199 additions & 0 deletions src/api/routers/test_vertex.py
@@ -0,0 +1,199 @@
import pytest
import json
from unittest.mock import patch, MagicMock
from fastapi import Request
from starlette.datastructures import Headers, QueryParams
from fastapi import Response

import api.routers.vertex as vertex

@pytest.fixture
def dummy_request():
    class DummyRequest:
        def __init__(self, headers=None, body=None, method="POST", query_params=None):
            self.headers = Headers(headers or {})
            self._body = body or b'{}'
            self.method = method
            self.query_params = QueryParams(query_params or {})

        async def body(self):
            return self._body

    return DummyRequest

def test_to_vertex_anthropic():
    openai_messages = {
        "messages": [
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi there!"}
        ]
    }
    result = vertex.to_vertex_anthropic(openai_messages)
    assert result["anthropic_version"] == "vertex-2023-10-16"
    assert result["max_tokens"] == 256
    assert isinstance(result["messages"], list)
    assert result["messages"][0]["role"] == "user"
    assert result["messages"][0]["content"][0]["text"] == "Hello!"
    assert result["messages"][1]["role"] == "assistant"
    assert result["messages"][1]["content"][0]["text"] == "Hi there!"

def test_from_anthropic_to_openai_response():
    msg = json.dumps({
        "id": "abc123",
        "role": "assistant",
        "content": [{"type": "text", "text": "Hello!"}, {"type": "text", "text": "Bye!"}],
        "stop_reason": "stop",
        "usage": {"prompt_tokens": 5, "completion_tokens": 2}
    })
    result = json.loads(vertex.from_anthropic_to_openai_response(msg))
    assert result["id"] == "abc123"
    assert result["object"] == "chat.completion"
    assert len(result["choices"]) == 1
    assert result["choices"][0]["message"]["content"] == "Hello!Bye!"
    assert result["choices"][0]["finish_reason"] == "stop"
    assert result["usage"]["prompt_tokens"] == 5

def test_get_gcp_target_env(monkeypatch):
    monkeypatch.setenv("PROXY_TARGET", "https://custom-proxy")
    result = vertex.get_gcp_target("any-model", "/v1/chat/completions")
    assert result == "https://custom-proxy"

def test_get_gcp_target_known_chat(monkeypatch):
    monkeypatch.delenv("PROXY_TARGET", raising=False)
    model = vertex.known_chat_models[0]
    path = "/v1/chat/completions"
    result = vertex.get_gcp_target(model, path)
    assert "endpoints/openapi/chat/completions" in result

def test_get_gcp_target_raw_predict(monkeypatch):
    monkeypatch.delenv("PROXY_TARGET", raising=False)
    model = "unknown-model"
    path = "/v1/other"
    result = vertex.get_gcp_target(model, path)
    assert ":rawPredict" in result

@patch("api.routers.vertex.get_access_token", return_value="dummy-token")
def test_get_header_removes_hop_headers(mock_token, dummy_request):
req = dummy_request(headers={
"Host": "example.com",
"Content-Length": "123",
"Accept-Encoding": "gzip",
"Connection": "keep-alive",
"Authorization": "Bearer old",
"X-Custom": "foo"
})
model = "test-model"
path = "/v1/chat/completions"
with patch("api.routers.vertex.get_gcp_target", return_value="http://target"):
target_url, headers = vertex.get_header(model, req, path)
assert target_url == "http://target"
assert "Host" not in headers
assert "Content-Length" not in headers
assert "Accept-Encoding" not in headers
assert "Connection" not in headers
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer dummy-token"
assert headers["x-custom"] == "foo"

@pytest.mark.asyncio
@patch("api.routers.vertex.httpx.AsyncClient")
@patch("api.routers.vertex.get_header")
@patch("api.routers.vertex.get_model", return_value="test-model")
async def test_handle_proxy_basic(mock_get_model, mock_get_header, mock_async_client, dummy_request):
    req = dummy_request(body=json.dumps({"model": "foo"}).encode())
    mock_get_header.return_value = ("http://target", {"Authorization": "Bearer token"})
    mock_response = MagicMock()
    mock_response.content = b'{"candidates":[{"content":{"parts":[{"text":"hi"}]}, "finishReason":"STOP"}]}'
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "application/json"}
    mock_async_client.return_value.__aenter__.return_value.request.return_value = mock_response

    vertex.USE_MODEL_MAPPING = True
    vertex.known_chat_models.append("test-model")
    result = await vertex.handle_proxy(req, "/v1/chat/completions")
    assert result.status_code == 200
    assert b"hi" in result.body
    assert result.headers["content-type"] == "application/json"

@pytest.mark.asyncio
@patch("api.routers.vertex.httpx.AsyncClient")
@patch("api.routers.vertex.get_header")
@patch("api.routers.vertex.get_model", return_value="test-model")
async def test_handle_proxy_known_chat_model(
    mock_get_model, mock_get_header, mock_async_client, dummy_request
):
    req = dummy_request(body=json.dumps({"model": "foo"}).encode())
    mock_get_header.return_value = ("http://target", {"Authorization": "Bearer token"})
    mock_response = MagicMock()
    mock_response.content = b'{"candidates":[{"content":{"parts":[{"text":"hi"}]}, "finishReason":"STOP"}]}'
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "application/json"}
    mock_async_client.return_value.__aenter__.return_value.request.return_value = mock_response

    vertex.USE_MODEL_MAPPING = True
    if "test-model" not in vertex.known_chat_models:
        vertex.known_chat_models.append("test-model")

    result = await vertex.handle_proxy(req, "/v1/chat/completions")
    assert isinstance(result, Response)
    assert result.status_code == 200
    assert b"hi" in result.body
    assert result.headers["content-type"] == "application/json"

@pytest.mark.asyncio
@patch("api.routers.vertex.httpx.AsyncClient")
@patch("api.routers.vertex.get_header")
@patch("api.routers.vertex.get_model", return_value="anthropic-model")
async def test_handle_proxy_anthropic_conversion(
    mock_get_model, mock_get_header, mock_async_client, dummy_request
):
    req = dummy_request(body=json.dumps({"model": "foo", "messages": [{"role": "user", "content": "hi"}]}).encode())
    mock_get_header.return_value = ("http://target", {"Authorization": "Bearer token"})
    mock_response = MagicMock()
    # Simulate anthropic response
    anthropic_resp = json.dumps({
        "id": "abc123",
        "role": "assistant",
        "content": [{"type": "text", "text": "Hello!"}],
        "stop_reason": "stop",
        "usage": {"prompt_tokens": 5, "completion_tokens": 2}
    }).encode()
    mock_response.content = anthropic_resp
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "application/json"}
    mock_async_client.return_value.__aenter__.return_value.request.return_value = mock_response

    vertex.USE_MODEL_MAPPING = True
    # Ensure model is not in known_chat_models to trigger conversion
    if "anthropic-model" in vertex.known_chat_models:
        vertex.known_chat_models.remove("anthropic-model")
    result = await vertex.handle_proxy(req, "/v1/chat/completions")
    assert isinstance(result, Response)
    data = json.loads(result.body)
    assert data["object"] == "chat.completion"
    assert data["choices"][0]["message"]["content"] == "Hello!"

@pytest.mark.asyncio
@patch("api.routers.vertex.httpx.AsyncClient", side_effect=Exception("network error"))
@patch("api.routers.vertex.get_header")
@patch("api.routers.vertex.get_model", return_value="test-model")
async def test_handle_proxy_httpx_exception(
    mock_get_model, mock_get_header, mock_async_client, dummy_request
):
    req = dummy_request(body=json.dumps({"model": "foo"}).encode())
    mock_get_header.return_value = ("http://target", {"Authorization": "Bearer token"})
    vertex.USE_MODEL_MAPPING = True
    if "test-model" not in vertex.known_chat_models:
        vertex.known_chat_models.append("test-model")
    # Patch httpx.RequestError to be raised
    with patch("api.routers.vertex.httpx.RequestError", Exception):
        result = await vertex.handle_proxy(req, "/v1/chat/completions")
        assert isinstance(result, Response)

        # Assert that the status code is 502 (Bad Gateway) due to upstream failure
        assert result.status_code == 502

        # Assert that the response body contains the expected error message
        assert b"Upstream request failed" in result.body
