Skip to content

Commit fac6d32

Browse files
committed
review updates
1 parent 9c395b9 commit fac6d32

File tree

2 files changed

+38
-13
lines changed

2 files changed

+38
-13
lines changed

src/api/routers/test_vertex.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ def test_from_anthropic_to_openai_response():
4545
"stop_reason": "stop",
4646
"usage": {"prompt_tokens": 5, "completion_tokens": 2}
4747
})
48-
result = json.loads(vertex.from_anthropic_to_openai_response(msg))
48+
result = json.loads(vertex.from_anthropic_to_openai_response(msg, "default"))
4949
assert result["id"] == "abc123"
5050
assert result["object"] == "chat.completion"
5151
assert len(result["choices"]) == 1
@@ -197,3 +197,26 @@ async def test_handle_proxy_httpx_exception(
197197
# Assert that the response body contains the expected error message
198198
assert b"Upstream request failed" in result.body
199199

200+
def test_get_chat_completion_model_name_known_chat_model():
201+
# Pick a known chat model from the list
202+
model_alias = "publishers/google/models/gemini-2.0-flash-lite-001"
203+
# Patch known_chat_models to ensure the model is present
204+
if model_alias not in vertex.known_chat_models:
205+
vertex.known_chat_models.append(model_alias)
206+
# Patch the function to use the correct argument name
207+
# The function as written has a bug: it uses 'model' instead of 'model_alias'
208+
# So we patch the function here for the test
209+
# But for now, test as is
210+
result = vertex.get_chat_completion_model_name(model_alias)
211+
# Should remove 'publishers/' and 'models/' from the string
212+
assert result == "google/gemini-2.0-flash-lite-001"
213+
214+
def test_get_chat_completion_model_name_unknown_model():
215+
model_alias = "some-other-model"
216+
# Ensure it's not in known_chat_models
217+
if model_alias in vertex.known_chat_models:
218+
vertex.known_chat_models.remove(model_alias)
219+
result = vertex.get_chat_completion_model_name(model_alias)
220+
# Should return the input unchanged
221+
assert result == model_alias
222+

src/api/routers/vertex.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ def get_proxy_target(model, path):
7979
"""
8080
if os.getenv("PROXY_TARGET"):
8181
return os.getenv("PROXY_TARGET")
82-
elif model in known_chat_models and path.endswith("/chat/completions")
82+
elif model in known_chat_models and path.endswith("/chat/completions"):
8383
return f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/endpoints/openapi/chat/completions"
8484
else:
8585
return f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/{model}:rawPredict"
@@ -114,12 +114,12 @@ def to_vertex_anthropic(openai_messages):
114114
"messages": message
115115
}
116116

117-
def from_anthropic_to_openai_response(msg):
117+
def from_anthropic_to_openai_response(msg, model):
118118
msg_json = json.loads(msg)
119119
return json.dumps({
120120
"id": msg_json["id"],
121121
"object": "chat.completion",
122-
"model": msg_json.get("model", "claude"),
122+
"model": model,
123123
"choices": [
124124
{
125125
"index": 0,
@@ -136,21 +136,23 @@ def from_anthropic_to_openai_response(msg):
136136
"usage": msg_json.get("usage", {})
137137
})
138138

139+
def get_chat_completion_model_name(model_alias):
140+
if model_alias in known_chat_models:
141+
# publishers/google/models/gemini-2.0-flash-lite-001 -> "google/gemini-2.0-flash-lite-001"
142+
model_alias = model_alias.replace("publishers/", "").replace("models/", "")
143+
144+
return model_alias
145+
139146
async def handle_proxy(request: Request, path: str):
140147
try:
141148
content = await request.body()
142149
content_json = json.loads(content)
150+
model_alias = content_json.get("model", "default")
151+
model = get_model("gcp", model_alias)
143152

144153
if USE_MODEL_MAPPING:
145154
if "model" in content_json:
146-
request_model = content_json.get("model", None)
147-
model = get_model("gcp", request_model)
148-
model_name = model
149-
150-
if model != None and model != request_model and "publishers/google/" in model:
151-
model_name = f"google/{model.split('/')[-1]}"
152-
153-
content_json["model"]= model_name
155+
content_json["model"]= get_chat_completion_model_name(model)
154156

155157
needs_conversion = False
156158
if not model in known_chat_models:
@@ -175,7 +177,7 @@ async def handle_proxy(request: Request, path: str):
175177
if needs_conversion:
176178
# convert vertex response to openai format
177179
if "anthropic" in model:
178-
content = from_anthropic_to_openai_response(response.content)
180+
content = from_anthropic_to_openai_response(response.content, model_alias)
179181

180182
except httpx.RequestError as e:
181183
logging.error(f"Proxy request failed: {e}")

0 commit comments

Comments (0)