
add model mapping for embeddings #21

Open · wants to merge 6 commits into base: defang
4 changes: 4 additions & 0 deletions Makefile
@@ -29,3 +29,7 @@ push: no-diff login
.PHONY: login
login: ## Login to docker
@docker login

.PHONY: tests
tests:
PYTHONPATH=src pytest
8 changes: 4 additions & 4 deletions src/api/app.py
@@ -10,7 +10,7 @@

from api.setting import API_ROUTE_PREFIX, DESCRIPTION, SUMMARY, PROVIDER, TITLE, USE_MODEL_MAPPING, VERSION
from api.modelmapper import load_model_map
from api.routers.vertex import handle_proxy
from api.routers.gcp.chat import handle_proxy

def is_aws():
env = os.getenv("AWS_EXECUTION_ENV")
@@ -54,10 +54,10 @@ def is_aws():
)

if provider != "aws":
from api.routers.gcp import chat, embeddings
logging.info(f"Proxy target set to: GCP")
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def proxy(request: Request, path: str):
return await handle_proxy(request, path)
app.include_router(chat.router, prefix=API_ROUTE_PREFIX)
app.include_router(embeddings.router, prefix=API_ROUTE_PREFIX)
else:
from api.routers import chat, embeddings, model
logging.info("No proxy target set. Using internal routers.")
48 changes: 48 additions & 0 deletions src/api/gcp/credentials/metadata.py
@@ -0,0 +1,48 @@
import logging
import requests

from google.auth import default
from google.auth.transport.requests import Request as AuthRequest

from api.setting import GCP_PROJECT_ID, GCP_REGION


# GCP credentials and project details
credentials = None
project_id = None
location = None

def get_gcp_project_details():
from google.auth import default

# Try metadata server for region
credentials = None
project_id = GCP_PROJECT_ID
location = GCP_REGION

try:
credentials, project = default()
if not project_id:
project_id = project

if not location:
zone = requests.get(
"http://metadata.google.internal/computeMetadata/v1/instance/zone",
headers={"Metadata-Flavor": "Google"},
timeout=1
).text
location = zone.split("/")[-1].rsplit("-", 1)[0]

except Exception:
logging.warning(f"Error: Failed to get project and location from metadata server. Using local settings.")

return credentials, project_id, location

credentials, project_id, location = get_gcp_project_details()

# Utility: get service account access token
def get_access_token():
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
auth_request = AuthRequest()
credentials.refresh(auth_request)
return credentials.token
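
The region above is derived from the zone string returned by the metadata server; a minimal illustration of that parsing (the example zone value follows GCP's documented format):

# Illustration only: reduce a fully qualified zone to its region.
zone = "projects/123456789/zones/us-central1-a"
region = zone.split("/")[-1].rsplit("-", 1)[0]
assert region == "us-central1"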
2 changes: 2 additions & 0 deletions src/api/models/bedrock.py
@@ -39,6 +39,7 @@
UserMessage,
)
from api.setting import AWS_REGION, DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE
from api.modelmapper import get_model

logger = logging.getLogger(__name__)

@@ -868,6 +869,7 @@ def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:


def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
model_id = get_model("aws", model_id)
model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
if DEBUG:
logger.info("model name is " + model_name)
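
The api.modelmapper module itself is not part of this diff; a minimal sketch of the contract get_model(provider, alias) appears to satisfy, assuming a nested provider-to-alias map (the entries below are illustrative, not from the PR):

# Hypothetical sketch of api/modelmapper.py (not shown in this PR).
_MODEL_MAP = {
    "aws": {"embedding-default": "cohere.embed-multilingual-v3"},
    "gcp": {"embedding-default": "publishers/google/models/text-embedding-005"},
}

def get_model(provider: str, alias: str) -> str:
    # Fall back to the alias itself when no mapping exists.
    return _MODEL_MAP.get(provider, {}).get(alias, alias)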
58 changes: 13 additions & 45 deletions src/api/routers/vertex.py → src/api/routers/gcp/chat.py
@@ -2,16 +2,17 @@
import json
import logging
import os
import requests
import uuid

from fastapi import Request, Response
from fastapi import APIRouter, Depends, Request, Response
from contextlib import asynccontextmanager
from api.setting import API_ROUTE_PREFIX, GCP_PROJECT_ID, GCP_REGION, USE_MODEL_MAPPING
from google.auth import default
from google.auth.transport.requests import Request as AuthRequest

from api.auth import api_key_auth
from api.modelmapper import get_model
from api.gcp.credentials.metadata import get_access_token, project_id, location
from api.schema import ChatResponse, ChatStreamResponse, Error

known_chat_models = [
"publishers/mistral-ai/models/mistral-7b-instruct-v0.3",
@@ -30,46 +31,10 @@
"publishers/meta/models/llama2-7b",
]


# GCP credentials and project details
credentials = None
project_id = None
location = None

def get_gcp_project_details():
from google.auth import default

# Try metadata server for region
credentials = None
project_id = GCP_PROJECT_ID
location = GCP_REGION

try:
credentials, project = default()
if not project_id:
project_id = project

if not location:
zone = requests.get(
"http://metadata.google.internal/computeMetadata/v1/instance/zone",
headers={"Metadata-Flavor": "Google"},
timeout=1
).text
location = zone.split("/")[-1].rsplit("-", 1)[0]

except Exception:
logging.warning(f"Error: Failed to get project and location from metadata server. Using local settings.")

return credentials, project_id, location

credentials, project_id, location = get_gcp_project_details()

# Utility: get service account access token
def get_access_token():
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
auth_request = AuthRequest()
credentials.refresh(auth_request)
return credentials.token
router = APIRouter(
prefix="/chat",
dependencies=[Depends(api_key_auth)],
)

def get_proxy_target(model, path):
"""
@@ -82,7 +47,7 @@ def get_proxy_target(model, path):
else:
return f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/{model}:rawPredict"

def get_headers(model, request, path):
def get_header(model, request, path):

Suggested change
-def get_header(model, request, path):
+def get_headers(model, request, path):

path_no_prefix = f"/{path.lstrip('/')}".removeprefix(API_ROUTE_PREFIX)
target_url = get_proxy_target(model, path_no_prefix)

@@ -140,6 +105,9 @@ def get_chat_completion_model_name(model_alias):

return model_alias.split('/')[-1]

@router.post(
"/completions", response_model=ChatResponse | ChatStreamResponse | Error, response_model_exclude_unset=True
)
async def handle_proxy(request: Request, path: str):
try:
content = await request.body()
@@ -159,7 +127,7 @@ async def handle_proxy(request: Request, path: str):
conversion_target = "anthropic"

# Build safe target URL
target_url, request_headers = get_headers(model, request, path)
target_url, request_headers = get_header(model, request, path)

Suggested change
-    target_url, request_headers = get_header(model, request, path)
+    target_url, request_headers = get_headers(model, request, path)

async with httpx.AsyncClient() as client:
response = await client.request(
method=request.method,
122 changes: 122 additions & 0 deletions src/api/routers/gcp/embeddings.py
@@ -0,0 +1,122 @@
import httpx
import json
import logging
import os

from fastapi import APIRouter, Depends, Request, Response
from api.auth import api_key_auth
from api.schema import EmbeddingsResponse
from api.setting import API_ROUTE_PREFIX

from api.modelmapper import get_model
from api.gcp.credentials.metadata import get_access_token, project_id, location

router = APIRouter(
prefix="/embeddings",
dependencies=[Depends(api_key_auth)],
)

def get_proxy_target(model, path):
    """
    Return the upstream target URL: the PROXY_TARGET environment variable
    override if set, otherwise the Vertex AI predict endpoint.
    """
if os.getenv("PROXY_TARGET"):
return os.getenv("PROXY_TARGET")
else:
return f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/{model}:predict"

def get_header(model, request, path):
path_no_prefix = f"/{path.lstrip('/')}".removeprefix(API_ROUTE_PREFIX)
target_url = get_proxy_target(model, path_no_prefix)

# remove hop-by-hop headers
headers = {
k: v for k, v in request.headers.items()
if k.lower() not in {"host", "content-length", "accept-encoding", "connection", "authorization"}
}

# Fetch service account token
access_token = get_access_token()
headers["Authorization"] = f"Bearer {access_token}"
return target_url, headers

def to_vertex_embeddings(request):
"""
Convert OpenAI-style embeddings request to Vertex AI format.
"""
vertex_request = {
"instances": []
}

msg_input = request.get("input")
if type(msg_input) is str:
vertex_request["instances"] = [{
"content": f"{msg_input}"
}]
elif type(msg_input) is list:
vertex_request["instances"] = [{"content": f"{str(item)}"} for item in msg_input]

return vertex_request
Comment on lines +46 to +58
I think we can tighten this up a bit. What do you think of this?

Suggested change
-    vertex_request = {
-        "instances": []
-    }
-
-    msg_input = request.get("input")
-    if type(msg_input) is str:
-        vertex_request["instances"] = [{
-            "content": f"{msg_input}"
-        }]
-    elif type(msg_input) is list:
-        vertex_request["instances"] = [{"content": f"{str(item)}"} for item in msg_input]
-
-    return vertex_request
+    inputs = request.get("input", [])
+    if not isinstance(inputs, list):
+        inputs = [inputs]
+    return {
+        "instances": [{"content": str(content)} for content in inputs]
+    }
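
Either version produces the same payload; a quick worked example:

# OpenAI-style body -> Vertex AI predict payload.
assert to_vertex_embeddings({"input": ["hello", "world"]}) == {
    "instances": [{"content": "hello"}, {"content": "world"}]
}
# A bare string is wrapped as a single instance.
assert to_vertex_embeddings({"input": "hello"}) == {
    "instances": [{"content": "hello"}]
}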


def to_openai_response(embedding_content, model):
"""
Convert Vertex AI embeddings response to OpenAI format.
"""
total_tokens = sum(
item["embeddings"]["statistics"]["token_count"]
for item in embedding_content.get("predictions", [])
)

return {
"data": [
{
"embedding": item["embeddings"]["values"],
"index": idx,
"object": "embedding",
} for idx, item in enumerate(embedding_content.get("predictions", []))
],
"model": model,
"object": "list",
"usage": {
"total_tokens": total_tokens
}
}
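
For reference, this reshaping assumes the Vertex text-embedding response layout; a small worked example with abbreviated, illustrative values:

vertex_response = {
    "predictions": [
        {"embeddings": {"values": [0.1, 0.2], "statistics": {"token_count": 2}}}
    ]
}
assert to_openai_response(vertex_response, "embedding-default") == {
    "data": [{"embedding": [0.1, 0.2], "index": 0, "object": "embedding"}],
    "model": "embedding-default",
    "object": "list",
    "usage": {"total_tokens": 2},
}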

@router.post("/{path:path}", response_model=EmbeddingsResponse)
async def handle_proxy(request: Request, path: str):
try:
content = await request.body()
content_json = json.loads(content)
model_alias = content_json.get("model", "embedding-default")
model = get_model("gcp", model_alias)

# Build safe target URL
target_url, request_headers = get_header(model, request, path)
vertex_embedding_content = to_vertex_embeddings(content_json)
async with httpx.AsyncClient() as client:
response = await client.request(
method=request.method,
url=target_url,
headers=request_headers,
content=json.dumps(vertex_embedding_content),
params=request.query_params,
timeout=5.0,
)

content = to_openai_response(json.loads(response.content), model_alias)

except httpx.RequestError as e:
logging.error(f"Proxy request failed: {e}")
return Response(status_code=502, content=f"Upstream request failed: {e}")

# remove hop-by-hop headers
response_headers = {
k: v for k, v in response.headers.items()
if k.lower() not in {"content-encoding", "transfer-encoding", "connection"}
}

return Response(
content=json.dumps(content),
status_code=response.status_code,
headers=response_headers,
media_type=response.headers.get("content-type", "application/octet-stream"),
)
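
End to end, the route accepts a plain OpenAI-style embeddings request; a sketch of a client call (host, route prefix, and auth scheme are deployment-specific assumptions):

import httpx

# Assumes the gateway listens locally with API_ROUTE_PREFIX "/api/v1"
# and that api_key_auth expects a bearer token; both are deployment-specific.
resp = httpx.post(
    "http://localhost:8000/api/v1/embeddings",
    headers={"Authorization": "Bearer <api-key>"},
    json={"model": "embedding-default", "input": ["hello", "world"]},
)
resp.raise_for_status()
print(resp.json()["usage"]["total_tokens"])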