Skip to content

Commit 323c54d

Browse files
fix(route_llm_request.py): map team model from list in route llm request
remove unnecessary proxy model table lookup for model alias Fixes issue where aliases weren't being consistently written to model table
1 parent 058503e commit 323c54d

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

litellm/proxy/route_llm_request.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,25 @@ def __init__(self, route: str, model_name: str):
3333
super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)
3434

3535

36+
def get_team_id_from_data(data: dict) -> Optional[str]:
37+
"""
38+
Get the team id from the data's metadata or litellm_metadata params.
39+
"""
40+
if (
41+
"metadata" in data
42+
and data["metadata"] is not None
43+
and "user_api_key_team_id" in data["metadata"]
44+
):
45+
return data["metadata"].get("user_api_key_team_id")
46+
elif (
47+
"litellm_metadata" in data
48+
and data["litellm_metadata"] is not None
49+
and "user_api_key_team_id" in data["litellm_metadata"]
50+
):
51+
return data["litellm_metadata"].get("user_api_key_team_id")
52+
return None
53+
54+
3655
async def route_request(
3756
data: dict,
3857
llm_router: Optional[LitellmRouter],
@@ -55,6 +74,7 @@ async def route_request(
5574
"""
5675
Common helper to route the request
5776
"""
77+
team_id = get_team_id_from_data(data)
5878
router_model_names = llm_router.model_names if llm_router is not None else []
5979
if "api_key" in data or "api_base" in data:
6080
return getattr(llm_router, f"{route_type}")(**data)
@@ -78,7 +98,16 @@ async def route_request(
7898
models = [model.strip() for model in data.pop("model").split(",")]
7999
return llm_router.abatch_completion(models=models, **data)
80100
elif llm_router is not None:
81-
if (
101+
team_model_name = (
102+
llm_router.map_team_model(data["model"], team_id)
103+
if team_id is not None
104+
else None
105+
)
106+
if team_model_name is not None:
107+
data["model"] = team_model_name
108+
return getattr(llm_router, f"{route_type}")(**data)
109+
110+
elif (
82111
data["model"] in router_model_names
83112
or data["model"] in llm_router.get_model_ids()
84113
):

litellm/router.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5407,6 +5407,26 @@ def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
54075407
ids.append(id)
54085408
return ids
54095409

5410+
def map_team_model(self, team_model_name: str, team_id: str) -> Optional[str]:
5411+
"""
5412+
Map a team model name to a team-specific model name.
5413+
5414+
Returns:
5415+
- team_model_name: str - the team-specific model name
5416+
- None: if no team-specific model name is found
5417+
"""
5418+
for model in self.model_list:
5419+
model_team_id = model["model_info"].get("team_id")
5420+
model_team_public_model_name = model["model_info"].get(
5421+
"team_public_model_name"
5422+
)
5423+
if (
5424+
model_team_id == team_id
5425+
and model_team_public_model_name == team_model_name
5426+
):
5427+
return model["model_name"]
5428+
return None
5429+
54105430
def _get_all_deployments(
54115431
self, model_name: str, model_alias: Optional[str] = None
54125432
) -> List[DeploymentTypedDict]:

0 commit comments

Comments
 (0)