Skip to content

Commit 655ce2e

Browse files
authored
Merge pull request #9683 from BerriAI/litellm_fix_service_account_behavior
[Bug fix] - Service accounts - only apply `service_account_settings.enforced_params` on service accounts
2 parents 61b609f + 4ddca7a commit 655ce2e

File tree

6 files changed

+133
-104
lines changed

6 files changed

+133
-104
lines changed

litellm/proxy/auth/service_account_checks.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

litellm/proxy/auth/user_api_key_auth.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
from litellm.proxy.auth.handle_jwt import JWTAuthManager, JWTHandler
5050
from litellm.proxy.auth.oauth2_check import check_oauth2_token
5151
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
52-
from litellm.proxy.auth.service_account_checks import service_account_checks
5352
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
5453
from litellm.proxy.utils import PrismaClient, ProxyLogging
5554
from litellm.types.services import ServiceTypes
@@ -905,12 +904,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
905904
else:
906905
_team_obj = None
907906

908-
# Check 7: Check if key is a service account key
909-
await service_account_checks(
910-
valid_token=valid_token,
911-
request_data=request_data,
912-
)
913-
914907
user_api_key_cache.set_cache(
915908
key=valid_token.team_id, value=_team_obj
916909
) # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py

litellm/proxy/litellm_pre_call_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,10 @@ def _get_enforced_params(
747747
enforced_params: Optional[list] = None
748748
if general_settings is not None:
749749
enforced_params = general_settings.get("enforced_params")
750-
if "service_account_settings" in general_settings:
750+
if (
751+
"service_account_settings" in general_settings
752+
and check_if_token_is_service_account(user_api_key_dict) is True
753+
):
751754
service_account_settings = general_settings["service_account_settings"]
752755
if "enforced_params" in service_account_settings:
753756
if enforced_params is None:
@@ -760,6 +763,20 @@ def _get_enforced_params(
760763
return enforced_params
761764

762765

766+
def check_if_token_is_service_account(valid_token: UserAPIKeyAuth) -> bool:
767+
"""
768+
Checks if the token is a service account
769+
770+
Returns:
771+
bool: True if token is a service account
772+
773+
"""
774+
if valid_token.metadata:
775+
if "service_account_id" in valid_token.metadata:
776+
return True
777+
return False
778+
779+
763780
def _enforced_params_check(
764781
request_body: dict,
765782
general_settings: Optional[dict],

litellm/proxy/proxy_config.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,3 @@ model_list:
44
model: openai/fake
55
api_key: fake-key
66
api_base: https://exampleopenaiendpoint-production.up.railway.app/
7-
8-
general_settings:
9-
use_redis_transaction_buffer: true
10-
11-
litellm_settings:
12-
cache: True
13-
cache_params:
14-
type: redis
15-
supported_call_types: []
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import json
2+
import os
3+
import sys
4+
from unittest.mock import MagicMock, patch
5+
6+
import pytest
7+
8+
from litellm.proxy._types import UserAPIKeyAuth
9+
from litellm.proxy.litellm_pre_call_utils import (
10+
_get_enforced_params,
11+
check_if_token_is_service_account,
12+
)
13+
14+
sys.path.insert(
15+
0, os.path.abspath("../../..")
16+
) # Adds the parent directory to the system path
17+
18+
19+
def test_check_if_token_is_service_account():
20+
"""
21+
Test that only keys with `service_account_id` in metadata are considered service accounts
22+
"""
23+
# Test case 1: Service account token
24+
service_account_token = UserAPIKeyAuth(
25+
api_key="test-key", metadata={"service_account_id": "test-service-account"}
26+
)
27+
assert check_if_token_is_service_account(service_account_token) == True
28+
29+
# Test case 2: Regular user token
30+
regular_token = UserAPIKeyAuth(api_key="test-key", metadata={})
31+
assert check_if_token_is_service_account(regular_token) == False
32+
33+
# Test case 3: Token with other metadata
34+
other_metadata_token = UserAPIKeyAuth(
35+
api_key="test-key", metadata={"user_id": "test-user"}
36+
)
37+
assert check_if_token_is_service_account(other_metadata_token) == False
38+
39+
40+
def test_get_enforced_params_for_service_account_settings():
41+
"""
42+
Test that service account enforced params are only added to service account keys
43+
"""
44+
service_account_token = UserAPIKeyAuth(
45+
api_key="test-key", metadata={"service_account_id": "test-service-account"}
46+
)
47+
general_settings_with_service_account_settings = {
48+
"service_account_settings": {"enforced_params": ["metadata.service"]},
49+
}
50+
result = _get_enforced_params(
51+
general_settings=general_settings_with_service_account_settings,
52+
user_api_key_dict=service_account_token,
53+
)
54+
assert result == ["metadata.service"]
55+
56+
regular_token = UserAPIKeyAuth(
57+
api_key="test-key", metadata={"enforced_params": ["user"]}
58+
)
59+
result = _get_enforced_params(
60+
general_settings=general_settings_with_service_account_settings,
61+
user_api_key_dict=regular_token,
62+
)
63+
assert result == ["user"]
64+
65+
66+
@pytest.mark.parametrize(
67+
"general_settings, user_api_key_dict, expected_enforced_params",
68+
[
69+
(
70+
{"enforced_params": ["param1", "param2"]},
71+
UserAPIKeyAuth(
72+
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
73+
),
74+
["param1", "param2"],
75+
),
76+
(
77+
{"service_account_settings": {"enforced_params": ["param1", "param2"]}},
78+
UserAPIKeyAuth(
79+
api_key="test_api_key",
80+
user_id="test_user_id",
81+
org_id="test_org_id",
82+
metadata={"service_account_id": "test_service_account_id"},
83+
),
84+
["param1", "param2"],
85+
),
86+
(
87+
{"service_account_settings": {"enforced_params": ["param1", "param2"]}},
88+
UserAPIKeyAuth(
89+
api_key="test_api_key",
90+
metadata={
91+
"enforced_params": ["param3", "param4"],
92+
"service_account_id": "test_service_account_id",
93+
},
94+
),
95+
["param1", "param2", "param3", "param4"],
96+
),
97+
],
98+
)
99+
def test_get_enforced_params(
100+
general_settings, user_api_key_dict, expected_enforced_params
101+
):
102+
from litellm.proxy.litellm_pre_call_utils import _get_enforced_params
103+
104+
enforced_params = _get_enforced_params(general_settings, user_api_key_dict)
105+
assert enforced_params == expected_enforced_params

tests/proxy_unit_tests/test_proxy_utils.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -770,56 +770,31 @@ async def test_add_litellm_data_to_request_duplicate_tags(
770770

771771

772772
@pytest.mark.parametrize(
773-
"general_settings, user_api_key_dict, expected_enforced_params",
773+
"general_settings, user_api_key_dict, request_body, expected_error",
774774
[
775775
(
776776
{"enforced_params": ["param1", "param2"]},
777777
UserAPIKeyAuth(
778778
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
779779
),
780-
["param1", "param2"],
781-
),
782-
(
783-
{"service_account_settings": {"enforced_params": ["param1", "param2"]}},
784-
UserAPIKeyAuth(
785-
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
786-
),
787-
["param1", "param2"],
788-
),
789-
(
790-
{"service_account_settings": {"enforced_params": ["param1", "param2"]}},
791-
UserAPIKeyAuth(
792-
api_key="test_api_key",
793-
metadata={"enforced_params": ["param3", "param4"]},
794-
),
795-
["param1", "param2", "param3", "param4"],
780+
{},
781+
True,
796782
),
797-
],
798-
)
799-
def test_get_enforced_params(
800-
general_settings, user_api_key_dict, expected_enforced_params
801-
):
802-
from litellm.proxy.litellm_pre_call_utils import _get_enforced_params
803-
804-
enforced_params = _get_enforced_params(general_settings, user_api_key_dict)
805-
assert enforced_params == expected_enforced_params
806-
807-
808-
@pytest.mark.parametrize(
809-
"general_settings, user_api_key_dict, request_body, expected_error",
810-
[
811783
(
812-
{"enforced_params": ["param1", "param2"]},
784+
{"service_account_settings": {"enforced_params": ["user"]}},
813785
UserAPIKeyAuth(
814786
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
815787
),
816788
{},
817-
True,
789+
False,
818790
),
819791
(
820792
{"service_account_settings": {"enforced_params": ["user"]}},
821793
UserAPIKeyAuth(
822-
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
794+
api_key="test_api_key",
795+
user_id="test_user_id",
796+
org_id="test_org_id",
797+
metadata={"service_account_id": "test_service_account_id"}
823798
),
824799
{},
825800
True,
@@ -854,6 +829,7 @@ def test_get_enforced_params(
854829
{"service_account_settings": {"enforced_params": ["user"]}},
855830
UserAPIKeyAuth(
856831
api_key="test_api_key",
832+
metadata={"service_account_id": "test_service_account_id"}
857833
),
858834
{"user": "test_user"},
859835
False,

0 commit comments

Comments
 (0)