1
1
import logging
2
- import requests
2
+ import os
3
3
import uvicorn
4
- from fastapi import FastAPI , Request , Response
4
+
5
+ from fastapi import FastAPI , Request
5
6
from fastapi .exceptions import RequestValidationError
6
7
from fastapi .middleware .cors import CORSMiddleware
7
8
from fastapi .responses import PlainTextResponse
8
9
from mangum import Mangum
9
- import httpx
10
- import json
11
- import os
12
- from contextlib import asynccontextmanager
13
-
14
- from api .setting import API_ROUTE_PREFIX , DESCRIPTION , GCP_PROJECT_ID , GCP_REGION , SUMMARY , PROVIDER , TITLE , USE_MODEL_MAPPING , VERSION
15
10
16
- from google .auth import default
17
- from google .auth .transport .requests import Request as AuthRequest
18
-
19
- from api .modelmapper import get_model , load_model_map
20
-
21
- # GCP credentials and project details
22
- credentials = None
23
- project_id = None
24
- location = None
11
+ from api .setting import API_ROUTE_PREFIX , DESCRIPTION , SUMMARY , PROVIDER , TITLE , USE_MODEL_MAPPING , VERSION
12
+ from api .modelmapper import load_model_map
13
+ from api .routers .vertex import handle_proxy
25
14
26
15
def is_aws ():
27
16
env = os .getenv ("AWS_EXECUTION_ENV" )
@@ -43,115 +32,6 @@ def is_aws():
43
32
if USE_MODEL_MAPPING :
44
33
load_model_map ()
45
34
46
-
47
- def get_gcp_project_details ():
48
- from google .auth import default
49
-
50
- # Try metadata server for region
51
- credentials = None
52
- project_id = GCP_PROJECT_ID
53
- location = GCP_REGION
54
-
55
- try :
56
- credentials , project = default ()
57
- if not project_id :
58
- project_id = project
59
-
60
- if not location :
61
- zone = requests .get (
62
- "http://metadata.google.internal/computeMetadata/v1/instance/zone" ,
63
- headers = {"Metadata-Flavor" : "Google" },
64
- timeout = 1
65
- ).text
66
- location = zone .split ("/" )[- 1 ].rsplit ("-" , 1 )[0 ]
67
-
68
- except Exception :
69
- logging .warning (f"Error: Failed to get project and location from metadata server. Using local settings." )
70
-
71
- return credentials , project_id , location
72
-
73
- if not is_aws ():
74
- credentials , project_id , location = get_gcp_project_details ()
75
-
76
- # Utility: get service account access token
77
- def get_access_token ():
78
- credentials , _ = default (scopes = ["https://www.googleapis.com/auth/cloud-platform" ])
79
- auth_request = AuthRequest ()
80
- credentials .refresh (auth_request )
81
- return credentials .token
82
-
83
- def get_gcp_target (path ):
84
- """
85
- Check if the environment variable is set to use GCP.
86
- """
87
- if os .getenv ("PROXY_TARGET" ):
88
- return os .getenv ("PROXY_TARGET" )
89
- else :
90
- return f"https://{ location } -aiplatform.googleapis.com/v1/projects/{ project_id } /locations/{ location } /{ path .lstrip ('/' )} " .rstrip ("/" )
91
-
92
- def get_header (request , path ):
93
- if "chat/completions" in path :
94
- path = path .replace ("chat/completions" , "endpoints/openapi/chat/completions" )
95
-
96
- path_no_prefix = f"/{ path .lstrip ('/' )} " .removeprefix (API_ROUTE_PREFIX )
97
- target_url = get_gcp_target (path_no_prefix )
98
-
99
- # remove hop-by-hop headers
100
- headers = {
101
- k : v for k , v in request .headers .items ()
102
- if k .lower () not in {"host" , "content-length" , "accept-encoding" , "connection" , "authorization" }
103
- }
104
-
105
- # Fetch service account token
106
- access_token = get_access_token ()
107
- headers ["Authorization" ] = f"Bearer { access_token } "
108
- return target_url ,headers
109
-
110
- async def handle_proxy (request : Request , path : str ):
111
- # Build safe target URL
112
- target_url , headers = get_header (request , path )
113
-
114
- try :
115
- content = await request .body ()
116
-
117
- if USE_MODEL_MAPPING :
118
- data = json .loads (content )
119
- if "model" in data :
120
- request_model = data .get ("model" , None )
121
- model = get_model ("gcp" , request_model )
122
-
123
- if model != None and model != request_model and "publishers/google/" in model :
124
- model = f"google/{ model .split ('/' )[- 1 ]} "
125
-
126
- data ["model" ]= model
127
- content = json .dumps (data )
128
-
129
- async with httpx .AsyncClient () as client :
130
- response = await client .request (
131
- method = request .method ,
132
- url = target_url ,
133
- headers = headers ,
134
- content = content ,
135
- params = request .query_params ,
136
- timeout = 30.0 ,
137
- )
138
- except httpx .RequestError as e :
139
- logging .error (f"Proxy request failed: { e } " )
140
- return Response (status_code = 502 , content = f"Upstream request failed: { e } " )
141
-
142
- # remove hop-by-hop headers
143
- response_headers = {
144
- k : v for k , v in response .headers .items ()
145
- if k .lower () not in {"content-encoding" , "transfer-encoding" , "connection" }
146
- }
147
-
148
- return Response (
149
- content = response .content ,
150
- status_code = response .status_code ,
151
- headers = response_headers ,
152
- media_type = response .headers .get ("content-type" , "application/octet-stream" ),
153
- )
154
-
155
35
config = {
156
36
"title" : TITLE ,
157
37
"description" : DESCRIPTION ,
0 commit comments