forked from aws-samples/bedrock-access-gateway
-
Notifications
You must be signed in to change notification settings - Fork 0
add model mapping for embeddings #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nullfunc
wants to merge
6
commits into
defang
Choose a base branch
from
eric-add-embeddings
base: defang
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
efa854a
add model mapping to aws embedding
nullfunc 4b596f5
fix tests
nullfunc b7f9e6b
working embedding translation
nullfunc 42716b0
add vertex format to openai embedding response
nullfunc dc68b73
update embeddings response
nullfunc fe69d90
add tests
nullfunc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import logging | ||
import requests | ||
|
||
from google.auth import default | ||
from google.auth.transport.requests import Request as AuthRequest | ||
|
||
from api.setting import GCP_PROJECT_ID, GCP_REGION | ||
|
||
|
||
# GCP credentials and project details | ||
credentials = None | ||
project_id = None | ||
location = None | ||
|
||
def get_gcp_project_details(): | ||
from google.auth import default | ||
|
||
# Try metadata server for region | ||
credentials = None | ||
project_id = GCP_PROJECT_ID | ||
location = GCP_REGION | ||
|
||
try: | ||
credentials, project = default() | ||
if not project_id: | ||
project_id = project | ||
|
||
if not location: | ||
zone = requests.get( | ||
"http://metadata.google.internal/computeMetadata/v1/instance/zone", | ||
headers={"Metadata-Flavor": "Google"}, | ||
timeout=1 | ||
).text | ||
location = zone.split("/")[-1].rsplit("-", 1)[0] | ||
|
||
except Exception: | ||
logging.warning(f"Error: Failed to get project and location from metadata server. Using local settings.") | ||
|
||
return credentials, project_id, location | ||
|
||
credentials, project_id, location = get_gcp_project_details() | ||
|
||
# Utility: get service account access token | ||
def get_access_token(): | ||
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) | ||
auth_request = AuthRequest() | ||
credentials.refresh(auth_request) | ||
return credentials.token |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2,16 +2,17 @@ | |||||
import json | ||||||
import logging | ||||||
import os | ||||||
import requests | ||||||
import uuid | ||||||
|
||||||
from fastapi import Request, Response | ||||||
from fastapi import APIRouter, Depends, Request, Response | ||||||
from contextlib import asynccontextmanager | ||||||
from api.setting import API_ROUTE_PREFIX, GCP_PROJECT_ID, GCP_REGION, USE_MODEL_MAPPING | ||||||
from google.auth import default | ||||||
from google.auth.transport.requests import Request as AuthRequest | ||||||
|
||||||
from api.auth import api_key_auth | ||||||
from api.modelmapper import get_model | ||||||
from api.gcp.credentials.metadata import get_access_token, project_id, location | ||||||
from api.schema import ChatResponse, ChatStreamResponse, Error | ||||||
|
||||||
known_chat_models = [ | ||||||
"publishers/mistral-ai/models/mistral-7b-instruct-v0.3", | ||||||
|
@@ -30,46 +31,10 @@ | |||||
"publishers/meta/models/llama2-7b", | ||||||
] | ||||||
|
||||||
|
||||||
# GCP credentials and project details | ||||||
credentials = None | ||||||
project_id = None | ||||||
location = None | ||||||
|
||||||
def get_gcp_project_details(): | ||||||
from google.auth import default | ||||||
|
||||||
# Try metadata server for region | ||||||
credentials = None | ||||||
project_id = GCP_PROJECT_ID | ||||||
location = GCP_REGION | ||||||
|
||||||
try: | ||||||
credentials, project = default() | ||||||
if not project_id: | ||||||
project_id = project | ||||||
|
||||||
if not location: | ||||||
zone = requests.get( | ||||||
"http://metadata.google.internal/computeMetadata/v1/instance/zone", | ||||||
headers={"Metadata-Flavor": "Google"}, | ||||||
timeout=1 | ||||||
).text | ||||||
location = zone.split("/")[-1].rsplit("-", 1)[0] | ||||||
|
||||||
except Exception: | ||||||
logging.warning(f"Error: Failed to get project and location from metadata server. Using local settings.") | ||||||
|
||||||
return credentials, project_id, location | ||||||
|
||||||
credentials, project_id, location = get_gcp_project_details() | ||||||
|
||||||
# Utility: get service account access token | ||||||
def get_access_token(): | ||||||
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) | ||||||
auth_request = AuthRequest() | ||||||
credentials.refresh(auth_request) | ||||||
return credentials.token | ||||||
router = APIRouter( | ||||||
prefix="/chat", | ||||||
dependencies=[Depends(api_key_auth)], | ||||||
) | ||||||
|
||||||
def get_proxy_target(model, path): | ||||||
""" | ||||||
|
@@ -82,7 +47,7 @@ def get_proxy_target(model, path): | |||||
else: | ||||||
return f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/{model}:rawPredict" | ||||||
|
||||||
def get_headers(model, request, path): | ||||||
def get_header(model, request, path): | ||||||
path_no_prefix = f"/{path.lstrip('/')}".removeprefix(API_ROUTE_PREFIX) | ||||||
target_url = get_proxy_target(model, path_no_prefix) | ||||||
|
||||||
|
@@ -140,6 +105,9 @@ def get_chat_completion_model_name(model_alias): | |||||
|
||||||
return model_alias.split('/')[-1] | ||||||
|
||||||
@router.post( | ||||||
"/completions", response_model=ChatResponse | ChatStreamResponse | Error, response_model_exclude_unset=True | ||||||
) | ||||||
async def handle_proxy(request: Request, path: str): | ||||||
try: | ||||||
content = await request.body() | ||||||
|
@@ -159,7 +127,7 @@ async def handle_proxy(request: Request, path: str): | |||||
conversion_target = "anthropic" | ||||||
|
||||||
# Build safe target URL | ||||||
target_url, request_headers = get_headers(model, request, path) | ||||||
target_url, request_headers = get_header(model, request, path) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
async with httpx.AsyncClient() as client: | ||||||
response = await client.request( | ||||||
method=request.method, | ||||||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,122 @@ | ||||||||||||||||||||||||||||||||||||||||
import httpx | ||||||||||||||||||||||||||||||||||||||||
import json | ||||||||||||||||||||||||||||||||||||||||
import logging | ||||||||||||||||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
from fastapi import APIRouter, Depends, Request, Response | ||||||||||||||||||||||||||||||||||||||||
from api.auth import api_key_auth | ||||||||||||||||||||||||||||||||||||||||
from api.schema import EmbeddingsResponse | ||||||||||||||||||||||||||||||||||||||||
from api.setting import API_ROUTE_PREFIX | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
from api.modelmapper import get_model | ||||||||||||||||||||||||||||||||||||||||
from api.gcp.credentials.metadata import get_access_token, project_id, location | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
router = APIRouter( | ||||||||||||||||||||||||||||||||||||||||
prefix="/embeddings", | ||||||||||||||||||||||||||||||||||||||||
dependencies=[Depends(api_key_auth)], | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
def get_proxy_target(model, path): | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
Check if the environment variable is set to use GCP. | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
if os.getenv("PROXY_TARGET"): | ||||||||||||||||||||||||||||||||||||||||
return os.getenv("PROXY_TARGET") | ||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||
return f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/{model}:predict" | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def get_header(model, request, path): | ||||||||||||||||||||||||||||||||||||||||
path_no_prefix = f"/{path.lstrip('/')}".removeprefix(API_ROUTE_PREFIX) | ||||||||||||||||||||||||||||||||||||||||
target_url = get_proxy_target(model, path_no_prefix) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# remove hop-by-hop headers | ||||||||||||||||||||||||||||||||||||||||
headers = { | ||||||||||||||||||||||||||||||||||||||||
k: v for k, v in request.headers.items() | ||||||||||||||||||||||||||||||||||||||||
if k.lower() not in {"host", "content-length", "accept-encoding", "connection", "authorization"} | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# Fetch service account token | ||||||||||||||||||||||||||||||||||||||||
access_token = get_access_token() | ||||||||||||||||||||||||||||||||||||||||
headers["Authorization"] = f"Bearer {access_token}" | ||||||||||||||||||||||||||||||||||||||||
return target_url, headers | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def to_vertex_embeddings(request): | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
Convert OpenAI-style embeddings request to Vertex AI format. | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
vertex_request = { | ||||||||||||||||||||||||||||||||||||||||
"instances": [] | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
msg_input = request.get("input") | ||||||||||||||||||||||||||||||||||||||||
if type(msg_input) is str: | ||||||||||||||||||||||||||||||||||||||||
vertex_request["instances"] = [{ | ||||||||||||||||||||||||||||||||||||||||
"content": f"{msg_input}" | ||||||||||||||||||||||||||||||||||||||||
}] | ||||||||||||||||||||||||||||||||||||||||
elif type(msg_input) is list: | ||||||||||||||||||||||||||||||||||||||||
vertex_request["instances"] = [{"content": f"{str(item)}"} for item in msg_input] | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
return vertex_request | ||||||||||||||||||||||||||||||||||||||||
Comment on lines
+46
to
+58
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can tighten this up a bit. What do you think of this?
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def to_openai_response(embedding_content, model): | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
Convert Vertex AI embeddings response to OpenAI format. | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
total_tokens = sum( | ||||||||||||||||||||||||||||||||||||||||
item["embeddings"]["statistics"]["token_count"] | ||||||||||||||||||||||||||||||||||||||||
for item in embedding_content.get("predictions", []) | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
return { | ||||||||||||||||||||||||||||||||||||||||
"data": [ | ||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||
"embedding": item["embeddings"]["values"], | ||||||||||||||||||||||||||||||||||||||||
"index": idx, | ||||||||||||||||||||||||||||||||||||||||
"object": "embedding", | ||||||||||||||||||||||||||||||||||||||||
} for idx, item in enumerate(embedding_content.get("predictions", [])) | ||||||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||||||
"model": model, | ||||||||||||||||||||||||||||||||||||||||
"object": "list", | ||||||||||||||||||||||||||||||||||||||||
"usage": { | ||||||||||||||||||||||||||||||||||||||||
"total_tokens": total_tokens | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
@router.post("/{path:path}", response_model=EmbeddingsResponse) | ||||||||||||||||||||||||||||||||||||||||
async def handle_proxy(request: Request, path: str): | ||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||
content = await request.body() | ||||||||||||||||||||||||||||||||||||||||
content_json = json.loads(content) | ||||||||||||||||||||||||||||||||||||||||
model_alias = content_json.get("model", "embedding-default") | ||||||||||||||||||||||||||||||||||||||||
model = get_model("gcp", model_alias) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# Build safe target URL | ||||||||||||||||||||||||||||||||||||||||
target_url, request_headers = get_header(model, request, path) | ||||||||||||||||||||||||||||||||||||||||
vertex_embedding_content = to_vertex_embeddings(content_json) | ||||||||||||||||||||||||||||||||||||||||
async with httpx.AsyncClient() as client: | ||||||||||||||||||||||||||||||||||||||||
response = await client.request( | ||||||||||||||||||||||||||||||||||||||||
method=request.method, | ||||||||||||||||||||||||||||||||||||||||
url=target_url, | ||||||||||||||||||||||||||||||||||||||||
headers=request_headers, | ||||||||||||||||||||||||||||||||||||||||
content=json.dumps(vertex_embedding_content), | ||||||||||||||||||||||||||||||||||||||||
params=request.query_params, | ||||||||||||||||||||||||||||||||||||||||
timeout=5.0, | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
content = to_openai_response(json.loads(response.content), model_alias) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
except httpx.RequestError as e: | ||||||||||||||||||||||||||||||||||||||||
logging.error(f"Proxy request failed: {e}") | ||||||||||||||||||||||||||||||||||||||||
return Response(status_code=502, content=f"Upstream request failed: {e}") | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# remove hop-by-hop headers | ||||||||||||||||||||||||||||||||||||||||
response_headers = { | ||||||||||||||||||||||||||||||||||||||||
k: v for k, v in response.headers.items() | ||||||||||||||||||||||||||||||||||||||||
if k.lower() not in {"content-encoding", "transfer-encoding", "connection"} | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
return Response( | ||||||||||||||||||||||||||||||||||||||||
content=json.dumps(content), | ||||||||||||||||||||||||||||||||||||||||
status_code=response.status_code, | ||||||||||||||||||||||||||||||||||||||||
headers=response_headers, | ||||||||||||||||||||||||||||||||||||||||
media_type=response.headers.get("content-type", "application/octet-stream"), | ||||||||||||||||||||||||||||||||||||||||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.