class Litellm_EntityType(enum.Enum):
    """
    Kinds of entity that spend can be attributed to on litellm.

    A member of this enum identifies which type of entity a database-tracked
    record (for example a spend increment) belongs to. The member values are
    plain lowercase strings and are used verbatim when building log messages.
    """

    # An API key.
    KEY = "key"
    # An internal user.
    USER = "user"
    # An end user (customer) passed along with a request.
    END_USER = "end_user"
    # A team.
    TEAM = "team"
    # A single user's membership inside a team.
    TEAM_MEMBER = "team_member"
    # An organization.
    ORGANIZATION = "organization"
class DBSpendUpdateWriter:
    """
    Queue spend increments on the proxy's ``PrismaClient``.

    All methods are static. ``update_database`` is the entry point: it fans a
    single request's cost out to the per-entity helpers (user, key, team, org)
    and, unless spend logs are disabled, records a spend-log payload.

    The code in this class only mutates in-memory structures on the client
    (the ``*_list_transactons`` dicts and ``spend_log_transactions``); it does
    not itself commit anything to the database.
    """

    # Strong references to in-flight fire-and-forget tasks. The event loop only
    # keeps weak references to tasks, so a task nobody references can be
    # garbage collected before it finishes.
    _background_tasks: set = set()

    @staticmethod
    def _on_background_task_done(task: asyncio.Task) -> None:
        """Drop the reference to a finished task and retrieve its exception."""
        DBSpendUpdateWriter._background_tasks.discard(task)
        if task.cancelled():
            # Task.exception() raises CancelledError for cancelled tasks.
            return
        error = task.exception()
        if error is not None:
            # Retrieving the exception keeps asyncio from reporting
            # "Task exception was never retrieved" for helpers that re-raise.
            verbose_proxy_logger.debug(
                f"Background spend update failed - {error!r}"
            )

    @staticmethod
    def _run_in_background(coro: Any) -> asyncio.Task:
        """Schedule ``coro`` as a task and keep it referenced until it is done."""
        task = asyncio.create_task(coro)
        DBSpendUpdateWriter._background_tasks.add(task)
        task.add_done_callback(DBSpendUpdateWriter._on_background_task_done)
        return task

    @staticmethod
    async def update_database(
        # LiteLLM management object fields
        token: Optional[str],
        user_id: Optional[str],
        end_user_id: Optional[str],
        team_id: Optional[str],
        org_id: Optional[str],
        # Completion object fields
        kwargs: Optional[dict],
        completion_response: Optional[Union[litellm.ModelResponse, Any, Exception]],
        start_time: Optional[datetime],
        end_time: Optional[datetime],
        response_cost: Optional[float],
    ):
        """
        Record the spend of one request against every entity it belongs to.

        The user, key, team and org updates are scheduled as background tasks
        and are not awaited; the spend-log insert is awaited inline unless
        ``disable_spend_logs`` is set. Nothing is done when
        ``ProxyUpdateSpend.disable_spend_updates()`` returns True.

        Args:
            token: API key of the caller. A value starting with ``"sk-"`` is
                hashed with ``hash_token`` first; any other value is used as-is.
            user_id: Internal user to charge, if any.
            end_user_id: End user to charge, if any.
            team_id: Team to charge, if any.
            org_id: Organization to charge, if any.
            kwargs: Request kwargs forwarded to ``get_logging_payload``.
            completion_response: Response (or exception) forwarded to
                ``get_logging_payload``.
            start_time: Request start, forwarded to ``get_logging_payload``.
            end_time: Request end, forwarded to ``get_logging_payload``.
            response_cost: Cost to add to each entity.

        Returns:
            None. Exceptions raised while scheduling or while writing the spend
            log are caught and logged at debug level, never propagated.
        """
        # Imported inside the function so the values are read from
        # proxy_server at call time.
        from litellm.proxy.proxy_server import (
            disable_spend_logs,
            litellm_proxy_budget_name,
            prisma_client,
            user_api_key_cache,
        )

        try:
            if isinstance(token, str) and token.startswith("sk-"):
                hashed_token = hash_token(token=token)
            else:
                hashed_token = token

            # Log the hashed value only: the raw "sk-..." key is a credential
            # and must not end up in log output.
            verbose_proxy_logger.debug(
                f"Enters prisma db call, response_cost: {response_cost}, token: {hashed_token}; user_id: {user_id}; team_id: {team_id}"
            )
            if ProxyUpdateSpend.disable_spend_updates() is True:
                return

            DBSpendUpdateWriter._run_in_background(
                DBSpendUpdateWriter._update_user_db(
                    response_cost=response_cost,
                    user_id=user_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    litellm_proxy_budget_name=litellm_proxy_budget_name,
                    end_user_id=end_user_id,
                )
            )
            DBSpendUpdateWriter._run_in_background(
                DBSpendUpdateWriter._update_key_db(
                    response_cost=response_cost,
                    hashed_token=hashed_token,
                    prisma_client=prisma_client,
                )
            )
            DBSpendUpdateWriter._run_in_background(
                DBSpendUpdateWriter._update_team_db(
                    response_cost=response_cost,
                    team_id=team_id,
                    user_id=user_id,
                    prisma_client=prisma_client,
                )
            )
            DBSpendUpdateWriter._run_in_background(
                DBSpendUpdateWriter._update_org_db(
                    response_cost=response_cost,
                    org_id=org_id,
                    prisma_client=prisma_client,
                )
            )
            if disable_spend_logs is False:
                await DBSpendUpdateWriter._insert_spend_log_to_db(
                    kwargs=kwargs,
                    completion_response=completion_response,
                    start_time=start_time,
                    end_time=end_time,
                    response_cost=response_cost,
                    prisma_client=prisma_client,
                )
            else:
                verbose_proxy_logger.info(
                    "disable_spend_logs=True. Skipping writing spend logs to db. Other spend updates - Key/User/Team table will still occur."
                )

            verbose_proxy_logger.debug("Runs spend update on all tables")
        except Exception:
            verbose_proxy_logger.debug(
                f"Error updating Prisma database: {traceback.format_exc()}"
            )

    @staticmethod
    async def _update_transaction_list(
        response_cost: Optional[float],
        entity_id: Optional[str],
        transaction_list: dict,
        entity_type: Litellm_EntityType,
        debug_msg: Optional[str] = None,
    ) -> bool:
        """
        Add ``response_cost`` to the running total kept for ``entity_id``.

        Args:
            response_cost: The cost to add. ``None`` is treated as ``0.0``.
            entity_id: The ID of the entity to update; nothing is done if None.
            transaction_list: Dict mapping entity id -> accumulated cost. It is
                mutated in place.
            entity_type: The type of entity (a ``Litellm_EntityType`` member);
                only used to build log messages.
            debug_msg: Optional custom debug message logged instead of the
                default one.

        Returns:
            bool: True if the entry was updated, False if ``entity_id`` is None.

        Raises:
            Exception: Any error raised while updating is logged and re-raised.
        """
        try:
            if debug_msg:
                verbose_proxy_logger.debug(debug_msg)
            else:
                verbose_proxy_logger.debug(
                    f"adding spend to {entity_type.value} db. Response cost: {response_cost}. {entity_type.value}_id: {entity_id}."
                )

            if entity_id is None:
                verbose_proxy_logger.debug(
                    f"track_cost_callback: {entity_type.value}_id is None. Not tracking spend for {entity_type.value}"
                )
                return False

            # The cost is Optional all the way up the call chain; adding None
            # to a number would raise TypeError.
            increment = 0.0 if response_cost is None else response_cost
            transaction_list[entity_id] = increment + transaction_list.get(
                entity_id, 0
            )
            return True

        except Exception as e:
            verbose_proxy_logger.info(
                f"Update {entity_type.value.capitalize()} DB failed to execute - {str(e)}\n{traceback.format_exc()}"
            )
            raise

    @staticmethod
    async def _update_key_db(
        response_cost: Optional[float],
        hashed_token: Optional[str],
        prisma_client: Optional[PrismaClient],
    ):
        """
        Add ``response_cost`` to ``prisma_client.key_list_transactons`` under
        ``hashed_token``. No-op when the token or the client is None; errors
        are logged and re-raised.
        """
        try:
            if hashed_token is None or prisma_client is None:
                return

            await DBSpendUpdateWriter._update_transaction_list(
                response_cost=response_cost,
                entity_id=hashed_token,
                transaction_list=prisma_client.key_list_transactons,
                entity_type=Litellm_EntityType.KEY,
                debug_msg=f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}.",
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                f"Update Key DB Call failed to execute - {str(e)}"
            )
            raise

    @staticmethod
    async def _update_user_db(
        response_cost: Optional[float],
        user_id: Optional[str],
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        litellm_proxy_budget_name: Optional[str],
        end_user_id: Optional[str] = None,
    ):
        """
        - Update that user's entry in ``prisma_client.user_list_transactons``
        - Update the litellm-proxy-budget entry (global proxy spend) when
          ``litellm.max_budget > 0``
        - Update the end user's entry in
          ``prisma_client.end_user_list_transactons`` when ``end_user_id`` is set

        ``user_api_key_cache`` is accepted for interface compatibility; this
        method does not read from it. Errors are logged and swallowed.
        """
        try:
            if prisma_client is None:
                return

            user_ids = [user_id]
            if (
                litellm.max_budget > 0
            ):  # track global proxy budget, if user set max budget
                user_ids.append(litellm_proxy_budget_name)

            for _id in user_ids:
                if _id is not None:
                    await DBSpendUpdateWriter._update_transaction_list(
                        response_cost=response_cost,
                        entity_id=_id,
                        transaction_list=prisma_client.user_list_transactons,
                        entity_type=Litellm_EntityType.USER,
                    )

            if end_user_id is not None:
                await DBSpendUpdateWriter._update_transaction_list(
                    response_cost=response_cost,
                    entity_id=end_user_id,
                    transaction_list=prisma_client.end_user_list_transactons,
                    entity_type=Litellm_EntityType.END_USER,
                )
        except Exception as e:
            verbose_proxy_logger.info(
                "\033[91m"
                + f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
            )

    @staticmethod
    async def _update_team_db(
        response_cost: Optional[float],
        team_id: Optional[str],
        user_id: Optional[str],
        prisma_client: Optional[PrismaClient],
    ):
        """
        Add ``response_cost`` to the team's entry in
        ``prisma_client.team_list_transactons`` and, when ``user_id`` is set,
        to that member's entry in ``prisma_client.team_member_list_transactons``.

        No-op when ``team_id`` or the client is None. A failure of the team
        update is logged and re-raised; the team-member update is best-effort.
        """
        try:
            if team_id is None or prisma_client is None:
                verbose_proxy_logger.debug(
                    "track_cost_callback: team_id is None or prisma_client is None. Not tracking spend for team"
                )
                return

            await DBSpendUpdateWriter._update_transaction_list(
                response_cost=response_cost,
                entity_id=team_id,
                transaction_list=prisma_client.team_list_transactons,
                entity_type=Litellm_EntityType.TEAM,
            )

            try:
                # Track spend of the team member within this team
                if user_id is not None:
                    # key is "team_id::<team_id>::user_id::<user_id>"
                    team_member_key = f"team_id::{team_id}::user_id::{user_id}"
                    await DBSpendUpdateWriter._update_transaction_list(
                        response_cost=response_cost,
                        entity_id=team_member_key,
                        transaction_list=prisma_client.team_member_list_transactons,
                        entity_type=Litellm_EntityType.TEAM_MEMBER,
                    )
            except Exception as member_error:
                # Best-effort: a failed member update must not fail the team
                # update, but it should still be visible in the logs.
                verbose_proxy_logger.debug(
                    f"Update Team Member DB failed to execute - {str(member_error)}"
                )
        except Exception as e:
            verbose_proxy_logger.info(
                f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}"
            )
            raise

    @staticmethod
    async def _update_org_db(
        response_cost: Optional[float],
        org_id: Optional[str],
        prisma_client: Optional[PrismaClient],
    ):
        """
        Add ``response_cost`` to ``prisma_client.org_list_transactons`` under
        ``org_id``. No-op when the id or the client is None; errors are logged
        and re-raised.
        """
        try:
            if org_id is None or prisma_client is None:
                verbose_proxy_logger.debug(
                    "track_cost_callback: org_id is None or prisma_client is None. Not tracking spend for org"
                )
                return

            await DBSpendUpdateWriter._update_transaction_list(
                response_cost=response_cost,
                entity_id=org_id,
                transaction_list=prisma_client.org_list_transactons,
                entity_type=Litellm_EntityType.ORGANIZATION,
            )
        except Exception as e:
            verbose_proxy_logger.info(
                f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}"
            )
            raise

    @staticmethod
    async def _insert_spend_log_to_db(
        kwargs: Optional[dict],
        completion_response: Optional[Union[litellm.ModelResponse, Any, Exception]],
        start_time: Optional[datetime],
        end_time: Optional[datetime],
        response_cost: Optional[float],
        prisma_client: Optional[PrismaClient],
    ):
        """
        Build the spend-log payload for one request and queue it on the client.

        The payload comes from ``get_logging_payload``; its ``"spend"`` field is
        set to ``response_cost`` (``0.0`` when the cost is None or 0). No-op
        when ``prisma_client`` is None. Errors are logged at debug level and
        re-raised.
        """
        try:
            if prisma_client is None:
                return

            payload = get_logging_payload(
                kwargs=kwargs,
                response_obj=completion_response,
                start_time=start_time,
                end_time=end_time,
            )
            payload["spend"] = response_cost or 0.0
            DBSpendUpdateWriter._set_spend_logs_payload(
                payload=payload,
                spend_logs_url=os.getenv("SPEND_LOGS_URL"),
                prisma_client=prisma_client,
            )
        except Exception as e:
            verbose_proxy_logger.debug(
                f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
            )
            raise

    @staticmethod
    def _set_spend_logs_payload(
        payload: Union[dict, SpendLogsPayload],
        prisma_client: PrismaClient,
        spend_logs_url: Optional[str] = None,
    ) -> PrismaClient:
        """
        Append ``payload`` to ``prisma_client.spend_log_transactions`` and hand
        a copy to ``add_spend_log_transaction_to_daily_user_transaction``.

        Args:
            payload: Spend-log record. When ``spend_logs_url`` is set, its
                ``startTime`` / ``endTime`` datetimes are replaced in place by
                their ISO-8601 strings before queuing.
            prisma_client: Client that owns the transaction queue. If it is
                None nothing is queued.
            spend_logs_url: Value of the ``SPEND_LOGS_URL`` setting, if any.

        Returns:
            The ``prisma_client`` that was passed in.
        """
        verbose_proxy_logger.info(
            "Writing spend log to db - request_id: {}, spend: {}".format(
                payload.get("request_id"), payload.get("spend")
            )
        )
        if prisma_client is None:
            # Nothing to queue on; avoids an AttributeError on the calls below.
            return prisma_client

        if spend_logs_url is not None:
            if isinstance(payload["startTime"], datetime):
                payload["startTime"] = payload["startTime"].isoformat()
            if isinstance(payload["endTime"], datetime):
                payload["endTime"] = payload["endTime"].isoformat()

        prisma_client.spend_log_transactions.append(payload)
        prisma_client.add_spend_log_transaction_to_daily_user_transaction(
            payload.copy()
        )
        return prisma_client
litellm.logging_callback_manager.add_litellm_callback(_ProxyDBLogger()) -def _set_spend_logs_payload( - payload: Union[dict, SpendLogsPayload], - prisma_client: PrismaClient, - spend_logs_url: Optional[str] = None, -): - verbose_proxy_logger.info( - "Writing spend log to db - request_id: {}, spend: {}".format( - payload.get("request_id"), payload.get("spend") - ) - ) - if prisma_client is not None and spend_logs_url is not None: - if isinstance(payload["startTime"], datetime): - payload["startTime"] = payload["startTime"].isoformat() - if isinstance(payload["endTime"], datetime): - payload["endTime"] = payload["endTime"].isoformat() - prisma_client.spend_log_transactions.append(payload) - elif prisma_client is not None: - prisma_client.spend_log_transactions.append(payload) - - prisma_client.add_spend_log_transaction_to_daily_user_transaction(payload.copy()) - return prisma_client - - -async def update_database( # noqa: PLR0915 - token, - response_cost, - user_id=None, - end_user_id=None, - team_id=None, - kwargs=None, - completion_response=None, - start_time=None, - end_time=None, - org_id=None, -): - try: - global prisma_client - verbose_proxy_logger.debug( - f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}" - ) - if ProxyUpdateSpend.disable_spend_updates() is True: - return - if token is not None and isinstance(token, str) and token.startswith("sk-"): - hashed_token = hash_token(token=token) - else: - hashed_token = token - - ### UPDATE USER SPEND ### - async def _update_user_db(): - """ - - Update that user's row - - Update litellm-proxy-budget row (global proxy spend) - """ - ## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db - existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) - if existing_user_obj is not None and isinstance(existing_user_obj, dict): - existing_user_obj = LiteLLM_UserTable(**existing_user_obj) - try: - if 
prisma_client is not None: # update - user_ids = [user_id] - if ( - litellm.max_budget > 0 - ): # track global proxy budget, if user set max budget - user_ids.append(litellm_proxy_budget_name) - ### KEY CHANGE ### - for _id in user_ids: - if _id is not None: - prisma_client.user_list_transactons[_id] = ( - response_cost - + prisma_client.user_list_transactons.get(_id, 0) - ) - if end_user_id is not None: - prisma_client.end_user_list_transactons[end_user_id] = ( - response_cost - + prisma_client.end_user_list_transactons.get( - end_user_id, 0 - ) - ) - except Exception as e: - verbose_proxy_logger.info( - "\033[91m" - + f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}" - ) - - ### UPDATE KEY SPEND ### - async def _update_key_db(): - try: - verbose_proxy_logger.debug( - f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}." - ) - if hashed_token is None: - return - if prisma_client is not None: - prisma_client.key_list_transactons[hashed_token] = ( - response_cost - + prisma_client.key_list_transactons.get(hashed_token, 0) - ) - except Exception as e: - verbose_proxy_logger.exception( - f"Update Key DB Call failed to execute - {str(e)}" - ) - raise e - - ### UPDATE SPEND LOGS ### - async def _insert_spend_log_to_db(): - try: - global prisma_client - if prisma_client is not None: - # Helper to generate payload to log - payload = get_logging_payload( - kwargs=kwargs, - response_obj=completion_response, - start_time=start_time, - end_time=end_time, - ) - payload["spend"] = response_cost - prisma_client = _set_spend_logs_payload( - payload=payload, - spend_logs_url=os.getenv("SPEND_LOGS_URL"), - prisma_client=prisma_client, - ) - except Exception as e: - verbose_proxy_logger.debug( - f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}" - ) - raise e - - ### UPDATE TEAM SPEND ### - async def _update_team_db(): - try: - verbose_proxy_logger.debug( - f"adding spend to team db. 
Response cost: {response_cost}. team_id: {team_id}." - ) - if team_id is None: - verbose_proxy_logger.debug( - "track_cost_callback: team_id is None. Not tracking spend for team" - ) - return - if prisma_client is not None: - prisma_client.team_list_transactons[team_id] = ( - response_cost - + prisma_client.team_list_transactons.get(team_id, 0) - ) - - try: - # Track spend of the team member within this team - # key is "team_id::::user_id::" - team_member_key = f"team_id::{team_id}::user_id::{user_id}" - prisma_client.team_member_list_transactons[team_member_key] = ( - response_cost - + prisma_client.team_member_list_transactons.get( - team_member_key, 0 - ) - ) - except Exception: - pass - except Exception as e: - verbose_proxy_logger.info( - f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}" - ) - raise e - - ### UPDATE ORG SPEND ### - async def _update_org_db(): - try: - verbose_proxy_logger.debug( - "adding spend to org db. Response cost: {}. org_id: {}.".format( - response_cost, org_id - ) - ) - if org_id is None: - verbose_proxy_logger.debug( - "track_cost_callback: org_id is None. Not tracking spend for org" - ) - return - if prisma_client is not None: - prisma_client.org_list_transactons[org_id] = ( - response_cost - + prisma_client.org_list_transactons.get(org_id, 0) - ) - except Exception as e: - verbose_proxy_logger.info( - f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}" - ) - raise e - - asyncio.create_task(_update_user_db()) - asyncio.create_task(_update_key_db()) - asyncio.create_task(_update_team_db()) - asyncio.create_task(_update_org_db()) - # asyncio.create_task(_insert_spend_log_to_db()) - if disable_spend_logs is False: - await _insert_spend_log_to_db() - else: - verbose_proxy_logger.info( - "disable_spend_logs=True. Skipping writing spend logs to db. Other spend updates - Key/User/Team table will still occur." 
- ) - - verbose_proxy_logger.debug("Runs spend update on all tables") - except Exception: - verbose_proxy_logger.debug( - f"Error updating Prisma database: {traceback.format_exc()}" - ) - - async def update_cache( # noqa: PLR0915 token: Optional[str], user_id: Optional[str], diff --git a/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py b/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py index 1e3b22ae2d3b..8850436329f0 100644 --- a/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py +++ b/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py @@ -47,7 +47,8 @@ async def test_async_post_call_failure_hook(): # Mock update_database function with patch( - "litellm.proxy.proxy_server.update_database", new_callable=AsyncMock + "litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter.update_database", + new_callable=AsyncMock, ) as mock_update_database: # Call the method await logger.async_post_call_failure_hook( diff --git a/tests/litellm/proxy/spend_tracking/test_spend_management_endpoints.py b/tests/litellm/proxy/spend_tracking/test_spend_management_endpoints.py index a5ee9ddf70da..e08bce432da3 100644 --- a/tests/litellm/proxy/spend_tracking/test_spend_management_endpoints.py +++ b/tests/litellm/proxy/spend_tracking/test_spend_management_endpoints.py @@ -416,7 +416,8 @@ async def test_spend_logs_payload_e2e(self): # litellm._turn_on_debug() with patch.object( - litellm.proxy.proxy_server, "_set_spend_logs_payload" + litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter, + "_set_spend_logs_payload", ) as mock_client, patch.object(litellm.proxy.proxy_server, "prisma_client"): response = await litellm.acompletion( model="gpt-4o", @@ -509,7 +510,8 @@ async def test_spend_logs_payload_success_log_with_api_base(self): client = AsyncHTTPHandler() with patch.object( - litellm.proxy.proxy_server, "_set_spend_logs_payload" + litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter, + "_set_spend_logs_payload", ) as mock_client, 
patch.object( litellm.proxy.proxy_server, "prisma_client" ), patch.object( @@ -604,7 +606,8 @@ async def test_spend_logs_payload_success_log_with_router(self): ) with patch.object( - litellm.proxy.proxy_server, "_set_spend_logs_payload" + litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter, + "_set_spend_logs_payload", ) as mock_client, patch.object( litellm.proxy.proxy_server, "prisma_client" ), patch.object( diff --git a/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py b/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py index 535f5bf019bd..129be6d75469 100644 --- a/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py +++ b/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py @@ -5,6 +5,7 @@ import pytest from fastapi import Request from litellm.proxy.utils import _get_redoc_url, _get_docs_url +from datetime import datetime sys.path.insert(0, os.path.abspath("../..")) import litellm @@ -22,16 +23,20 @@ async def test_disable_spend_logs(): with patch("litellm.proxy.proxy_server.disable_spend_logs", True), patch( "litellm.proxy.proxy_server.prisma_client", mock_prisma_client ): - from litellm.proxy.proxy_server import update_database + from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter # Call update_database with disable_spend_logs=True - await update_database( + await DBSpendUpdateWriter.update_database( token="fake-token", response_cost=0.1, user_id="user123", completion_response=None, - start_time="2024-01-01", - end_time="2024-01-01", + start_time=datetime.now(), + end_time=datetime.now(), + end_user_id="end_user_id", + team_id="team_id", + org_id="org_id", + kwargs={}, ) # Verify no spend logs were added assert len(mock_prisma_client.spend_log_transactions) == 0