diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index e5d3dfd56c59..ba38c4fb9324 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -151,7 +151,13 @@ async def async_moderation_hook(
         self,
         data: dict,
         user_api_key_dict: UserAPIKeyAuth,
-        call_type: Literal["completion", "embeddings", "image_generation"],
+        call_type: Literal[
+            "completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
     ):
         pass
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 1a624c5f894d..5238c2b51ecc 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -90,6 +90,13 @@
 from ..integrations.traceloop import TraceloopLogger
 from ..integrations.weights_biases import WeightsBiasesLogger
 
+try:
+    from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
+        GenericAPILogger,
+    )
+except Exception as e:
+    verbose_logger.debug(f"Exception import enterprise features {str(e)}")
+
 _in_memory_loggers: List[Any] = []
 
 ### GLOBAL VARIABLES ###
@@ -145,7 +152,41 @@ def set_cache(self, litellm_call_id: str, service_name: str, trace_id: str) -> N
         return None
 
 
+import hashlib
+
+
+class DynamicLoggingCache:
+    """
+    Prevent memory leaks caused by initializing new logging clients on each request.
+
+    Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
+    """
+
+    def __init__(self) -> None:
+        self.cache = InMemoryCache()
+
+    def get_cache_key(self, args: dict) -> str:
+        args_str = json.dumps(args, sort_keys=True)
+        cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
+        return cache_key
+
+    def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
+        key_name = self.get_cache_key(
+            args={**credentials, "service_name": service_name}
+        )
+        response = self.cache.get_cache(key=key_name)
+        return response
+
+    def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
+        key_name = self.get_cache_key(
+            args={**credentials, "service_name": service_name}
+        )
+        self.cache.set_cache(key=key_name, value=logging_obj)
+        return None
+
+
 in_memory_trace_id_cache = ServiceTraceIDCache()
+in_memory_dynamic_logger_cache = DynamicLoggingCache()
 
 
 class Logging:
@@ -324,10 +365,10 @@ def pre_call(self, input, api_key, model=None, additional_args={}):
             print_verbose(f"\033[92m{curl_command}\033[0m\n", log_level="DEBUG")
             # log raw request to provider (like LangFuse) -- if opted in.
             if log_raw_request_response is True:
+                _litellm_params = self.model_call_details.get("litellm_params", {})
+                _metadata = _litellm_params.get("metadata", {}) or {}
                 try:
                     # [Non-blocking Extra Debug Information in metadata]
-                    _litellm_params = self.model_call_details.get("litellm_params", {})
-                    _metadata = _litellm_params.get("metadata", {}) or {}
                     if (
                         turn_off_message_logging is not None
                         and turn_off_message_logging is True
@@ -362,7 +403,7 @@ def pre_call(self, input, api_key, model=None, additional_args={}):
         callbacks = litellm.input_callback + self.dynamic_input_callbacks
         for callback in callbacks:
             try:
-                if callback == "supabase":
+                if callback == "supabase" and supabaseClient is not None:
                     verbose_logger.debug("reaches supabase for logging!")
                     model = self.model_call_details["model"]
                     messages = self.model_call_details["input"]
@@ -396,7 +437,9 @@ def pre_call(self, input, api_key, model=None, additional_args={}):
                         messages=self.messages,
                         kwargs=self.model_call_details,
                     )
-                elif callable(callback):  # custom logger functions
+                elif (
+                    callable(callback) and customLogger is not None
+                ):  # custom logger functions
                     customLogger.log_input_event(
                         model=self.model,
                         messages=self.messages,
@@ -615,7 +658,7 @@ def _success_handler_helper_fn(
                     self.model_call_details["litellm_params"]["metadata"][
                         "hidden_params"
-                    ] = result._hidden_params
+                    ] = getattr(result, "_hidden_params", {})
 
                 ## STANDARDIZED LOGGING PAYLOAD
                 self.model_call_details["standard_logging_object"] = (
@@ -645,6 +688,7 @@ def _success_handler_helper_fn(
                 litellm.max_budget
                 and self.stream is False
                 and result is not None
+                and isinstance(result, dict)
                 and "content" in result
             ):
                 time_diff = (end_time - start_time).total_seconds()
@@ -652,7 +696,7 @@ def _success_handler_helper_fn(
                 litellm._current_cost += litellm.completion_cost(
                     model=self.model,
                     prompt="",
-                    completion=result["content"],
+                    completion=getattr(result, "content", ""),
                     total_time=float_diff,
                 )
@@ -758,7 +802,7 @@ def success_handler(
                 ):
                     print_verbose("no-log request, skipping logging")
                     continue
-                if callback == "lite_debugger":
+                if callback == "lite_debugger" and liteDebuggerClient is not None:
                     print_verbose("reaches lite_debugger for logging!")
                     print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
                     print_verbose(
@@ -774,7 +818,7 @@ def success_handler(
                         call_type=self.call_type,
                         stream=self.stream,
                     )
-                if callback == "promptlayer":
+                if callback == "promptlayer" and promptLayerLogger is not None:
                     print_verbose("reaches promptlayer for logging!")
                     promptLayerLogger.log_event(
                         kwargs=self.model_call_details,
@@ -783,7 +827,7 @@ def success_handler(
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
-                if callback == "supabase":
+                if callback == "supabase" and supabaseClient is not None:
                     print_verbose("reaches supabase for logging!")
                     kwargs = self.model_call_details
@@ -811,7 +855,7 @@ def success_handler(
                         ),
                         print_verbose=print_verbose,
                     )
-                if callback == "wandb":
+                if callback == "wandb" and weightsBiasesLogger is not None:
                     print_verbose("reaches wandb for logging!")
                     weightsBiasesLogger.log_event(
                         kwargs=self.model_call_details,
@@ -820,7 +864,7 @@ def success_handler(
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
-                if callback == "logfire":
+                if callback == "logfire" and logfireLogger is not None:
                     global logfireLogger
                     verbose_logger.debug("reaches logfire for success logging!")
                     kwargs = {}
@@ -844,10 +888,10 @@ def success_handler(
                         start_time=start_time,
                         end_time=end_time,
                         print_verbose=print_verbose,
-                        level=LogfireLevel.INFO.value,
+                        level=LogfireLevel.INFO.value,  # type: ignore
                     )
-                if callback == "lunary":
+                if callback == "lunary" and lunaryLogger is not None:
                     print_verbose("reaches lunary for logging!")
                     model = self.model
                     kwargs = self.model_call_details
@@ -882,7 +926,7 @@ def success_handler(
                         run_id=self.litellm_call_id,
                         print_verbose=print_verbose,
                     )
-                if callback == "helicone":
+                if callback == "helicone" and heliconeLogger is not None:
                     print_verbose("reaches helicone for logging!")
                     model = self.model
                     messages = self.model_call_details["input"]
@@ -924,6 +968,7 @@ def success_handler(
                         else:
                             print_verbose("reaches langfuse for streaming logging!")
                             result = kwargs["complete_streaming_response"]
+                    temp_langfuse_logger = langFuseLogger
                     if langFuseLogger is None or (
                         (
@@ -941,27 +986,45 @@ def success_handler(
                             and self.langfuse_host != langFuseLogger.langfuse_host
                         )
                     ):
-                        temp_langfuse_logger = LangFuseLogger(
-                            langfuse_public_key=self.langfuse_public_key,
-                            langfuse_secret=self.langfuse_secret,
-                            langfuse_host=self.langfuse_host,
+                        credentials = {
+                            "langfuse_public_key": self.langfuse_public_key,
+                            "langfuse_secret": self.langfuse_secret,
+                            "langfuse_host": self.langfuse_host,
+                        }
+                        temp_langfuse_logger = (
+                            in_memory_dynamic_logger_cache.get_cache(
+                                credentials=credentials, service_name="langfuse"
+                            )
                         )
-                    _response = temp_langfuse_logger.log_event(
-                        kwargs=kwargs,
-                        response_obj=result,
-                        start_time=start_time,
-                        end_time=end_time,
-                        user_id=kwargs.get("user", None),
-                        print_verbose=print_verbose,
-                    )
-                    if _response is not None and isinstance(_response, dict):
-                        _trace_id = _response.get("trace_id", None)
-                        if _trace_id is not None:
-                            in_memory_trace_id_cache.set_cache(
-                                litellm_call_id=self.litellm_call_id,
+                        if temp_langfuse_logger is None:
+                            temp_langfuse_logger = LangFuseLogger(
+                                langfuse_public_key=self.langfuse_public_key,
+                                langfuse_secret=self.langfuse_secret,
+                                langfuse_host=self.langfuse_host,
+                            )
+                            in_memory_dynamic_logger_cache.set_cache(
+                                credentials=credentials,
                                 service_name="langfuse",
-                                trace_id=_trace_id,
+                                logging_obj=temp_langfuse_logger,
                             )
+
+                    if temp_langfuse_logger is not None:
+                        _response = temp_langfuse_logger.log_event(
+                            kwargs=kwargs,
+                            response_obj=result,
+                            start_time=start_time,
+                            end_time=end_time,
+                            user_id=kwargs.get("user", None),
+                            print_verbose=print_verbose,
+                        )
+                        if _response is not None and isinstance(_response, dict):
+                            _trace_id = _response.get("trace_id", None)
+                            if _trace_id is not None:
+                                in_memory_trace_id_cache.set_cache(
+                                    litellm_call_id=self.litellm_call_id,
+                                    service_name="langfuse",
+                                    trace_id=_trace_id,
+                                )
                 if callback == "generic":
                     global genericAPILogger
                     verbose_logger.debug("reaches langfuse for success logging!")
@@ -982,7 +1045,7 @@ def success_handler(
                            print_verbose("reaches langfuse for streaming logging!")
                            result = kwargs["complete_streaming_response"]
                    if genericAPILogger is None:
-                        genericAPILogger = GenericAPILogger()
+                        genericAPILogger = GenericAPILogger()  # type: ignore
                    genericAPILogger.log_event(
                        kwargs=kwargs,
                        response_obj=result,
@@ -1022,7 +1085,7 @@ def success_handler(
                        user_id=kwargs.get("user", None),
                        print_verbose=print_verbose,
                    )
-                if callback == "greenscale":
+                if callback == "greenscale" and greenscaleLogger is not None:
                    kwargs = {}
                    for k, v in self.model_call_details.items():
                        if (
@@ -1066,7 +1129,7 @@ def success_handler(
                        result = kwargs["complete_streaming_response"]
                        # only add to cache once we have a complete streaming response
                        litellm.cache.add_cache(result, **kwargs)
-                if callback == "athina":
+                if callback == "athina" and athinaLogger is not None:
                    deep_copy = {}
                    for k, v in self.model_call_details.items():
                        deep_copy[k] = v
@@ -1224,6 +1287,7 @@ def success_handler(
                        "atranscription", False
                    )
                    is not True
+                    and customLogger is not None
                ):  # custom logger functions
                    print_verbose(
                        f"success callbacks: Running Custom Callback Function"
@@ -1423,9 +1487,9 @@ async def async_success_handler(
                        await litellm.cache.async_add_cache(result, **kwargs)
                    else:
                        litellm.cache.add_cache(result, **kwargs)
-                if callback == "openmeter":
+                if callback == "openmeter" and openMeterLogger is not None:
                    global openMeterLogger
-                    if self.stream == True:
+                    if self.stream is True:
                        if (
                            "async_complete_streaming_response"
                            in self.model_call_details
@@ -1645,33 +1709,9 @@ def failure_handler(
            )
            for callback in callbacks:
                try:
-                    if callback == "lite_debugger":
-                        print_verbose("reaches lite_debugger for logging!")
-                        print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
-                        result = {
-                            "model": self.model,
-                            "created": time.time(),
-                            "error": traceback_exception,
-                            "usage": {
-                                "prompt_tokens": prompt_token_calculator(
-                                    self.model, messages=self.messages
-                                ),
-                                "completion_tokens": 0,
-                            },
-                        }
-                        liteDebuggerClient.log_event(
-                            model=self.model,
-                            messages=self.messages,
-                            end_user=self.model_call_details.get("user", "default"),
-                            response_obj=result,
-                            start_time=start_time,
-                            end_time=end_time,
-                            litellm_call_id=self.litellm_call_id,
-                            print_verbose=print_verbose,
-                            call_type=self.call_type,
-                            stream=self.stream,
-                        )
-                    if callback == "lunary":
+                    if callback == "lite_debugger" and liteDebuggerClient is not None:
+                        pass
+                    elif callback == "lunary" and lunaryLogger is not None:
                        print_verbose("reaches lunary for logging error!")
 
                        model = self.model
@@ -1685,6 +1725,7 @@ def failure_handler(
                        )
 
                        lunaryLogger.log_event(
+                            kwargs=self.model_call_details,
                            type=_type,
                            event="error",
                            user_id=self.model_call_details.get("user", "default"),
@@ -1704,22 +1745,11 @@ def failure_handler(
                        print_verbose(
                            f"capture exception not initialized: {capture_exception}"
                        )
-                    elif callback == "supabase":
+                    elif callback == "supabase" and supabaseClient is not None:
                        print_verbose("reaches supabase for logging!")
                        print_verbose(f"supabaseClient: {supabaseClient}")
-                        result = {
-                            "model": model,
-                            "created": time.time(),
-                            "error": traceback_exception,
-                            "usage": {
-                                "prompt_tokens": prompt_token_calculator(
-                                    model, messages=self.messages
-                                ),
-                                "completion_tokens": 0,
-                            },
-                        }
                        supabaseClient.log_event(
-                            model=self.model,
+                            model=self.model if hasattr(self, "model") else "",
                            messages=self.messages,
                            end_user=self.model_call_details.get("user", "default"),
                            response_obj=result,
@@ -1728,7 +1758,9 @@ def failure_handler(
                            start_time=start_time,
                            end_time=end_time,
                            litellm_call_id=self.model_call_details["litellm_call_id"],
                            print_verbose=print_verbose,
                        )
-                    if callable(callback):  # custom logger functions
+                    if (
+                        callable(callback) and customLogger is not None
+                    ):  # custom logger functions
                        customLogger.log_event(
                            kwargs=self.model_call_details,
                            response_obj=result,
@@ -1809,13 +1841,13 @@ def failure_handler(
                            start_time=start_time,
                            end_time=end_time,
                            response_obj=None,
-                            user_id=kwargs.get("user", None),
+                            user_id=self.model_call_details.get("user", None),
                            print_verbose=print_verbose,
                            status_message=str(exception),
                            level="ERROR",
                            kwargs=self.model_call_details,
                        )
-                    if callback == "logfire":
+                    if callback == "logfire" and logfireLogger is not None:
                        verbose_logger.debug("reaches logfire for failure logging!")
                        kwargs = {}
                        for k, v in self.model_call_details.items():
@@ -1830,7 +1862,7 @@ def failure_handler(
                            response_obj=result,
                            start_time=start_time,
                            end_time=end_time,
-                            level=LogfireLevel.ERROR.value,
+                            level=LogfireLevel.ERROR.value,  # type: ignore
                            print_verbose=print_verbose,
                        )
@@ -1873,7 +1905,9 @@ async def async_failure_handler(
                            start_time=start_time,
                            end_time=end_time,
                        )  # type: ignore
-                    if callable(callback):  # custom logger functions
+                    if (
+                        callable(callback) and customLogger is not None
+                    ):  # custom logger functions
                        await customLogger.async_log_event(
                            kwargs=self.model_call_details,
                            response_obj=result,
@@ -1966,7 +2000,7 @@ def set_callbacks(callback_list, function_id=None):
            )
            sentry_sdk_instance.init(
                dsn=os.environ.get("SENTRY_DSN"),
-                traces_sample_rate=float(sentry_trace_rate),
+                traces_sample_rate=float(sentry_trace_rate),  # type: ignore
            )
            capture_exception = sentry_sdk_instance.capture_exception
            add_breadcrumb = sentry_sdk_instance.add_breadcrumb
@@ -2411,12 +2445,11 @@ def get_standard_logging_object_payload(
        saved_cache_cost: Optional[float] = None
        if cache_hit is True:
-            import time
 
            id = f"{id}_cache_hit{time.time()}"  # do not duplicate the request id
 
            saved_cache_cost = logging_obj._response_cost_calculator(
-                result=init_response_obj, cache_hit=False
+                result=init_response_obj, cache_hit=False  # type: ignore
            )
 
        ## Get model cost information ##
@@ -2473,7 +2506,7 @@ def get_standard_logging_object_payload(
            model_id=_model_id,
            requester_ip_address=clean_metadata.get("requester_ip_address", None),
            messages=kwargs.get("messages"),
-            response=(
+            response=(  # type: ignore
                response_obj if len(response_obj.keys()) > 0 else init_response_obj
            ),
            model_parameters=kwargs.get("optional_params", None),
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index b2e63c2564cc..36fe8ce73335 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -446,7 +446,6 @@ async def user_api_key_auth(
            and request.headers.get(key=header_key) is not None  # type: ignore
        ):
            api_key = request.headers.get(key=header_key)  # type: ignore
-
        if master_key is None:
            if isinstance(api_key, str):
                return UserAPIKeyAuth(
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index ef4903054f53..fa0be1f9e7d0 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -92,6 +92,7 @@ def safe_deep_copy(data):
    if litellm.safe_memory_mode is True:
        return data
 
+    litellm_parent_otel_span: Optional[Any] = None
    # Step 1: Remove the litellm_parent_otel_span
    if isinstance(data, dict):
        # remove litellm_parent_otel_span since this is not picklable
@@ -100,7 +101,7 @@ def safe_deep_copy(data):
    new_data = copy.deepcopy(data)
 
    # Step 2: re-add the litellm_parent_otel_span after doing a deep copy
-    if isinstance(data, dict):
+    if isinstance(data, dict) and litellm_parent_otel_span is not None:
        if "metadata" in data:
            data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
    return new_data
@@ -467,7 +468,7 @@ async def during_call_hook(
                # V1 implementation - backwards compatibility
                if callback.event_hook is None:
-                    if callback.moderation_check == "pre_call":
+                    if callback.moderation_check == "pre_call":  # type: ignore
                        return
                else:
                    # Main - V2 Guardrails implementation
@@ -992,7 +993,7 @@ async def check_view_exists(self):
                    return
                else:
                    ## check if required view exists ##
-                    if required_view not in ret[0]["view_names"]:
+                    if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
                        await self.health_check()  # make sure we can connect to db
                        await self.db.execute_raw(
                            """
@@ -1014,7 +1015,9 @@ async def check_view_exists(self):
                else:
                    # don't block execution if these views are missing
                    # Convert lists to sets for efficient difference calculation
-                    ret_view_names_set = set(ret[0]["view_names"])
+                    ret_view_names_set = (
+                        set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
+                    )
                    expected_views_set = set(expected_views)
                    # Find missing views
                    missing_views = expected_views_set - ret_view_names_set
@@ -1296,6 +1299,7 @@ async def get_data(
        verbose_proxy_logger.debug(
            f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
        )
+        hashed_token: Optional[str] = None
        try:
            response: Any = None
            if (token is not None and table_name is None) or (
@@ -1310,7 +1314,7 @@ async def get_data(
                verbose_proxy_logger.debug(
                    f"PrismaClient: find_unique for token: {hashed_token}"
                )
-                if query_type == "find_unique":
+                if query_type == "find_unique" and hashed_token is not None:
                    if token is None:
                        raise HTTPException(
                            status_code=400,
@@ -1711,7 +1715,7 @@ async def insert_data(
                    updated_data = json.dumps(updated_data)
                    updated_table_row = self.db.litellm_config.upsert(
                        where={"param_name": k},
-                        data={
+                        data={  # type: ignore
                            "create": {"param_name": k, "param_value": updated_data},
                            "update": {"param_value": updated_data},
                        },
@@ -2265,11 +2269,13 @@ async def disconnect(self):
        """
        For closing connection on server shutdown
        """
-        return await self.db.disconnect()
+        return await self.db.disconnect()  # type: ignore
 
 
 ### CUSTOM FILE ###
 def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
+    instance_name: Optional[str] = None
+    module_name: Optional[str] = None
    try:
        print_verbose(f"value: {value}")
        # Split the path by dots to separate module from instance
@@ -2302,7 +2308,12 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
        return instance
    except ImportError as e:
        # Re-raise the exception with a user-friendly message
-        raise ImportError(f"Could not import {instance_name} from {module_name}") from e
+        if instance_name and module_name:
+            raise ImportError(
+                f"Could not import {instance_name} from {module_name}"
+            ) from e
+        else:
+            raise e
    except Exception as e:
        raise e
@@ -2368,12 +2379,12 @@ async def send_email(receiver_email, subject, html):
    try:
        # Establish a secure connection with the SMTP server
-        with smtplib.SMTP(smtp_host, smtp_port) as server:
+        with smtplib.SMTP(smtp_host, smtp_port) as server:  # type: ignore
            if os.getenv("SMTP_TLS", "True") != "False":
                server.starttls()
 
            # Login to your email account
-            server.login(smtp_username, smtp_password)
+            server.login(smtp_username, smtp_password)  # type: ignore
 
            # Send the email
            server.send_message(email_message)
@@ -2930,10 +2941,10 @@ async def update_spend(
                    )
                    break
                except httpx.ReadTimeout:
-                    if i >= n_retry_times:  # If we've reached the maximum number of retries
+                    if i >= n_retry_times:  # type: ignore
                        raise  # Re-raise the last exception
                    # Optionally, sleep for a bit before retrying
-                    await asyncio.sleep(2**i)  # Exponential backoff
+                    await asyncio.sleep(2**i)  # type: ignore
                except Exception as e:
                    import traceback
@@ -3044,10 +3055,10 @@ def get_error_message_str(e: Exception) -> str:
        elif isinstance(e.detail, dict):
            error_message = json.dumps(e.detail)
    elif hasattr(e, "message"):
-        if isinstance(e.message, "str"):
-            error_message = e.message
-        elif isinstance(e.message, dict):
-            error_message = json.dumps(e.message)
+        if isinstance(e.message, "str"):  # type: ignore
+            error_message = e.message  # type: ignore
+        elif isinstance(e.message, dict):  # type: ignore
+            error_message = json.dumps(e.message)  # type: ignore
        else:
            error_message = str(e)
    else:
diff --git a/litellm/utils.py b/litellm/utils.py
index 280691c8a4c1..58ef4d49f1de 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -120,11 +120,26 @@
 # Convert to str (if necessary)
 claude_json_str = json.dumps(json_data)
 import importlib.metadata
+from concurrent.futures import ThreadPoolExecutor
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+    get_args,
+)
 
 from openai import OpenAIError as OriginalError
 
 from ._logging import verbose_logger
-from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
+from .caching import Cache, QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
 from .exceptions import (
     APIConnectionError,
     APIError,
@@ -150,31 +165,6 @@
 )
 from .types.router import LiteLLM_Params
 
-try:
-    from .proxy.enterprise.enterprise_callbacks.generic_api_callback import (
-        GenericAPILogger,
-    )
-except Exception as e:
-    verbose_logger.debug(f"Exception import enterprise features {str(e)}")
-
-from concurrent.futures import ThreadPoolExecutor
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-    cast,
-    get_args,
-)
-
-from .caching import Cache
-
 ####### ENVIRONMENT VARIABLES ####################
 # Adjust to your specific application needs / system capabilities.
 MAX_THREADS = 100
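The `DynamicLoggingCache` added in `litellm_logging.py` keys dynamically initialized logger clients by a SHA-256 hash of their credentials plus a service name, so repeated requests carrying the same Langfuse credentials reuse a single client rather than constructing a new one per request (the memory leak tracked in https://github.com/BerriAI/litellm/issues/5695). Below is a minimal standalone sketch of that caching pattern; the class name, the plain-dict store, and the stand-in client object are hypothetical and are not litellm APIs.

```python
import hashlib
import json
from typing import Any, Dict, Optional


class CredentialKeyedClientCache:
    """Illustrative sketch: reuse one logging client per unique credential set."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    def _key(self, credentials: dict, service_name: str) -> str:
        # Stable key: credentials plus service name, serialized deterministically, then hashed.
        raw = json.dumps({**credentials, "service_name": service_name}, sort_keys=True)
        return hashlib.sha256(raw.encode("utf-8")).hexdigest()

    def get(self, credentials: dict, service_name: str) -> Optional[Any]:
        return self._store.get(self._key(credentials, service_name))

    def set(self, credentials: dict, service_name: str, client: Any) -> None:
        self._store[self._key(credentials, service_name)] = client


if __name__ == "__main__":
    cache = CredentialKeyedClientCache()
    creds = {"public_key": "pk-demo", "secret": "sk-demo", "host": "https://example.invalid"}

    client = cache.get(creds, service_name="langfuse")
    if client is None:
        # Stand-in for an expensive client constructor (e.g. a Langfuse SDK client).
        client = object()
        cache.set(creds, service_name="langfuse", client=client)

    # A second lookup with the same credentials returns the cached instance.
    assert cache.get(creds, service_name="langfuse") is client
```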