From 3d79cfb06bed9cd32bf25482678724cfed60708e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 29 Mar 2025 08:56:55 -0700 Subject: [PATCH 01/14] build(pyproject.toml): add new dev dependencies - for type checking --- .github/workflows/test-linting.yml | 53 ++++++++++++++ poetry.lock | 113 ++++++++++++++++++++++++++++- pyproject.toml | 4 + 3 files changed, 166 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/test-linting.yml diff --git a/.github/workflows/test-linting.yml b/.github/workflows/test-linting.yml new file mode 100644 index 000000000000..d117d12de486 --- /dev/null +++ b/.github/workflows/test-linting.yml @@ -0,0 +1,53 @@ +name: LiteLLM Linting + +on: + pull_request: + branches: [ main ] + +jobs: + lint: + runs-on: ubuntu-latest + timeout-minutes: 5 + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install Poetry + uses: snok/install-poetry@v1 + + - name: Install dependencies + run: | + poetry install --with dev + + - name: Run Black formatting check + run: | + cd litellm + poetry run black . --check + cd .. + + - name: Run Ruff linting + run: | + cd litellm + poetry run ruff check . + cd .. + + - name: Run MyPy type checking + run: | + cd litellm + poetry run mypy . --ignore-missing-imports + cd .. + + - name: Check for circular imports + run: | + cd litellm + poetry run python ../tests/documentation_tests/test_circular_imports.py + cd .. + + - name: Check import safety + run: | + poetry run python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1) \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 64e8eaa8ff66..9c42278c14bf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -436,7 +436,7 @@ files = [ name = "cffi" version = "1.17.1" description = "Foreign Function Interface for Python calling C code." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, @@ -658,7 +658,7 @@ cron = ["capturer (>=2.4)"] name = "cryptography" version = "43.0.3" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." 
-optional = true +optional = false python-versions = ">=3.7" files = [ {file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"}, @@ -2563,7 +2563,7 @@ files = [ name = "pycparser" version = "2.22" description = "C parser in Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, @@ -3603,6 +3603,111 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] +[[package]] +name = "types-cffi" +version = "1.16.0.20241221" +description = "Typing stubs for cffi" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types_cffi-1.16.0.20241221-py3-none-any.whl", hash = "sha256:e5b76b4211d7a9185f6ab8d06a106d56c7eb80af7cdb8bfcb4186ade10fb112f"}, + {file = "types_cffi-1.16.0.20241221.tar.gz", hash = "sha256:1c96649618f4b6145f58231acb976e0b448be6b847f7ab733dabe62dfbff6591"}, +] + +[package.dependencies] +types-setuptools = "*" + +[[package]] +name = "types-pyopenssl" +version = "24.1.0.20240722" +description = "Typing stubs for pyOpenSSL" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-pyOpenSSL-24.1.0.20240722.tar.gz", hash = "sha256:47913b4678a01d879f503a12044468221ed8576263c1540dcb0484ca21b08c39"}, + {file = "types_pyOpenSSL-24.1.0.20240722-py3-none-any.whl", hash = "sha256:6a7a5d2ec042537934cfb4c9d4deb0e16c4c6250b09358df1f083682fe6fda54"}, +] + +[package.dependencies] +cryptography = ">=35.0.0" +types-cffi = "*" + +[[package]] +name = "types-pyyaml" +version = "6.0.12.20241230" +description = "Typing stubs for PyYAML" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types_PyYAML-6.0.12.20241230-py3-none-any.whl", hash = "sha256:fa4d32565219b68e6dee5f67534c722e53c00d1cfc09c435ef04d7353e1e96e6"}, + {file = "types_pyyaml-6.0.12.20241230.tar.gz", hash = "sha256:7f07622dbd34bb9c8b264fe860a17e0efcad00d50b5f27e93984909d9363498c"}, +] + +[[package]] +name = "types-redis" +version = "4.6.0.20241004" +description = "Typing stubs for redis" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-redis-4.6.0.20241004.tar.gz", hash = "sha256:5f17d2b3f9091ab75384153bfa276619ffa1cf6a38da60e10d5e6749cc5b902e"}, + {file = "types_redis-4.6.0.20241004-py3-none-any.whl", hash = "sha256:ef5da68cb827e5f606c8f9c0b49eeee4c2669d6d97122f301d3a55dc6a63f6ed"}, +] + +[package.dependencies] +cryptography = ">=35.0.0" +types-pyOpenSSL = "*" + +[[package]] +name = "types-requests" +version = "2.31.0.6" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.7" +files = [ + {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"}, + {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"}, +] + +[package.dependencies] +types-urllib3 = "*" + +[[package]] +name = "types-requests" +version = "2.32.0.20241016" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95"}, + {file = "types_requests-2.32.0.20241016-py3-none-any.whl", hash = "sha256:4195d62d6d3e043a4eaaf08ff8a62184584d2e8684e9d2aa178c7915a7da3747"}, +] + +[package.dependencies] +urllib3 = ">=2" + 
+[[package]] +name = "types-setuptools" +version = "75.8.0.20250110" +description = "Typing stubs for setuptools" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types_setuptools-75.8.0.20250110-py3-none-any.whl", hash = "sha256:a9f12980bbf9bcdc23ecd80755789085bad6bfce4060c2275bc2b4ca9f2bc480"}, + {file = "types_setuptools-75.8.0.20250110.tar.gz", hash = "sha256:96f7ec8bbd6e0a54ea180d66ad68ad7a1d7954e7281a710ea2de75e355545271"}, +] + +[[package]] +name = "types-urllib3" +version = "1.26.25.14" +description = "Typing stubs for urllib3" +optional = false +python-versions = "*" +files = [ + {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"}, + {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"}, +] + [[package]] name = "typing-extensions" version = "4.13.0" @@ -3993,4 +4098,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi", [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0, !=3.9.7" -content-hash = "0f195796116a7c7a4a04d9958a7662d74baccb5266f531bb58a4403fd4db4e0e" +content-hash = "efd348b2920530f18696f75f684026be542366eeec2f0b13f53ed9277a71c56a" diff --git a/pyproject.toml b/pyproject.toml index 2dbfcc39de18..fd17cdca178d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,10 @@ pytest = "^7.4.3" pytest-mock = "^3.12.0" pytest-asyncio = "^0.21.1" respx = "^0.20.2" +types-requests = "*" +types-setuptools = "*" +types-redis = "*" +types-PyYAML = "*" [tool.poetry.group.proxy-dev.dependencies] prisma = "0.11.0" From 9ba97bd37dfbd406cd4ab8ca2d8151700fb1934c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 29 Mar 2025 09:16:19 -0700 Subject: [PATCH 02/14] build: reformat files to fit black --- .pre-commit-config.yaml | 8 +- .../enterprise_hooks/secret_detection.py | 4 - litellm/__init__.py | 98 +++--- litellm/batches/main.py | 2 - litellm/caching/caching_handler.py | 11 +- litellm/caching/llm_caching_handler.py | 1 - litellm/caching/redis_cache.py | 1 - litellm/caching/redis_semantic_cache.py | 120 ++++--- litellm/cost_calculator.py | 1 - litellm/fine_tuning/main.py | 3 - .../SlackAlerting/batching_handler.py | 1 - .../SlackAlerting/slack_alerting.py | 28 +- litellm/integrations/_types/open_inference.py | 2 +- litellm/integrations/argilla.py | 1 - litellm/integrations/arize/arize.py | 1 - litellm/integrations/arize/arize_phoenix.py | 23 +- litellm/integrations/athina.py | 5 +- .../azure_storage/azure_storage.py | 13 +- litellm/integrations/braintrust_logging.py | 103 ++++-- litellm/integrations/custom_batch_logger.py | 1 - litellm/integrations/custom_guardrail.py | 1 - litellm/integrations/datadog/datadog.py | 1 - .../gcs_bucket/gcs_bucket_base.py | 6 +- litellm/integrations/gcs_pubsub/pub_sub.py | 13 +- litellm/integrations/humanloop.py | 6 +- litellm/integrations/langfuse/langfuse.py | 6 +- .../integrations/langfuse/langfuse_handler.py | 5 +- .../langfuse/langfuse_prompt_management.py | 6 +- litellm/integrations/langsmith.py | 8 +- litellm/integrations/lunary.py | 2 - litellm/integrations/mlflow.py | 9 +- litellm/integrations/opentelemetry.py | 12 +- litellm/integrations/opik/opik.py | 1 - litellm/integrations/prometheus.py | 6 +- .../integrations/prompt_management_base.py | 7 +- litellm/integrations/weights_biases.py | 8 +- .../litellm_core_utils/default_encoding.py | 4 +- litellm/litellm_core_utils/litellm_logging.py | 317 +++++++++--------- 
.../convert_dict_to_response.py | 8 +- .../litellm_core_utils/model_param_helper.py | 1 - .../prompt_templates/common_utils.py | 1 - .../prompt_templates/factory.py | 26 +- .../litellm_core_utils/realtime_streaming.py | 1 - litellm/litellm_core_utils/redact_messages.py | 6 +- .../sensitive_data_masker.py | 6 +- .../streaming_chunk_builder_utils.py | 1 - .../litellm_core_utils/streaming_handler.py | 20 +- litellm/llms/anthropic/chat/handler.py | 16 +- litellm/llms/anthropic/chat/transformation.py | 46 +-- .../anthropic/completion/transformation.py | 6 +- .../messages/handler.py | 28 +- litellm/llms/azure/azure.py | 6 - litellm/llms/azure/batches/handler.py | 72 ++-- litellm/llms/azure/common_utils.py | 1 - litellm/llms/azure/files/handler.py | 91 +++-- litellm/llms/azure/fine_tuning/handler.py | 9 +- litellm/llms/azure_ai/chat/transformation.py | 1 - .../azure_ai/embed/cohere_transformation.py | 1 - litellm/llms/azure_ai/embed/handler.py | 2 - .../llms/azure_ai/rerank/transformation.py | 1 + litellm/llms/base.py | 1 - litellm/llms/base_llm/chat/transformation.py | 1 - .../llms/base_llm/responses/transformation.py | 1 - litellm/llms/bedrock/chat/converse_handler.py | 10 +- .../bedrock/chat/converse_transformation.py | 20 +- litellm/llms/bedrock/chat/invoke_handler.py | 14 +- .../base_invoke_transformation.py | 7 +- litellm/llms/bedrock/common_utils.py | 1 - .../amazon_titan_multimodal_transformation.py | 7 +- .../amazon_nova_canvas_transformation.py | 79 +++-- litellm/llms/bedrock/image/image_handler.py | 13 +- litellm/llms/bedrock/rerank/handler.py | 1 - litellm/llms/bedrock/rerank/transformation.py | 1 - litellm/llms/codestral/completion/handler.py | 3 - .../codestral/completion/transformation.py | 1 - litellm/llms/cohere/chat/transformation.py | 2 - litellm/llms/cohere/embed/handler.py | 1 - litellm/llms/cohere/embed/transformation.py | 2 - litellm/llms/cohere/rerank/transformation.py | 2 +- .../llms/cohere/rerank_v2/transformation.py | 3 +- litellm/llms/custom_httpx/aiohttp_handler.py | 2 - litellm/llms/custom_httpx/http_handler.py | 4 - litellm/llms/custom_httpx/llm_http_handler.py | 5 - litellm/llms/databricks/common_utils.py | 6 +- .../llms/databricks/embed/transformation.py | 6 +- litellm/llms/databricks/streaming_utils.py | 1 - .../audio_transcription/transformation.py | 6 +- litellm/llms/deepseek/chat/transformation.py | 1 - .../llms/deprecated_providers/aleph_alpha.py | 6 +- .../llms/fireworks_ai/chat/transformation.py | 3 - litellm/llms/gemini/chat/transformation.py | 1 - litellm/llms/gemini/common_utils.py | 1 - litellm/llms/groq/chat/transformation.py | 1 - litellm/llms/groq/stt/transformation.py | 1 - .../llms/huggingface/chat/transformation.py | 24 +- litellm/llms/maritalk.py | 1 - .../llms/ollama/completion/transformation.py | 6 +- .../llms/openai/chat/gpt_transformation.py | 1 - litellm/llms/openai/completion/handler.py | 1 - .../llms/openai/completion/transformation.py | 6 +- litellm/llms/openai/fine_tuning/handler.py | 9 +- litellm/llms/openai/openai.py | 25 +- .../transcriptions/whisper_transformation.py | 6 +- .../llms/openrouter/chat/transformation.py | 8 +- .../llms/petals/completion/transformation.py | 6 +- litellm/llms/predibase/chat/handler.py | 1 - litellm/llms/predibase/chat/transformation.py | 12 +- litellm/llms/replicate/chat/handler.py | 1 - litellm/llms/sagemaker/chat/handler.py | 2 - litellm/llms/sagemaker/common_utils.py | 2 - litellm/llms/sagemaker/completion/handler.py | 10 +- .../sagemaker/completion/transformation.py | 6 +- 
.../llms/together_ai/rerank/transformation.py | 1 - .../topaz/image_variations/transformation.py | 2 - .../llms/triton/completion/transformation.py | 1 - litellm/llms/vertex_ai/batches/handler.py | 7 +- litellm/llms/vertex_ai/common_utils.py | 4 +- litellm/llms/vertex_ai/files/handler.py | 9 +- litellm/llms/vertex_ai/fine_tuning/handler.py | 15 +- .../llms/vertex_ai/gemini/transformation.py | 26 +- .../vertex_and_google_ai_studio_gemini.py | 18 +- .../batch_embed_content_handler.py | 1 - .../batch_embed_content_transformation.py | 1 - .../embedding_handler.py | 1 - .../multimodal_embeddings/transformation.py | 1 - .../llms/vertex_ai/vertex_ai_non_gemini.py | 2 - .../vertex_ai_partner_models/main.py | 1 - .../vertex_embeddings/embedding_handler.py | 12 +- .../vertex_embeddings/transformation.py | 1 - .../vertex_ai/vertex_model_garden/main.py | 1 - litellm/llms/watsonx/chat/transformation.py | 1 - litellm/main.py | 56 ++-- litellm/proxy/_types.py | 66 ++-- litellm/proxy/auth/auth_checks.py | 3 - .../proxy/auth/auth_checks_organization.py | 14 +- litellm/proxy/auth/auth_exception_handler.py | 1 - litellm/proxy/auth/auth_utils.py | 2 - litellm/proxy/auth/handle_jwt.py | 23 +- litellm/proxy/auth/litellm_license.py | 1 - litellm/proxy/auth/model_checks.py | 1 - litellm/proxy/auth/route_checks.py | 3 - litellm/proxy/auth/user_api_key_auth.py | 28 +- .../common_utils/encrypt_decrypt_utils.py | 2 - .../proxy/common_utils/http_parsing_utils.py | 8 +- litellm/proxy/db/log_db_metrics.py | 1 - litellm/proxy/db/redis_update_buffer.py | 22 +- litellm/proxy/guardrails/guardrail_helpers.py | 1 - .../guardrail_hooks/bedrock_guardrails.py | 1 - .../guardrails/guardrail_hooks/lakera_ai.py | 6 +- .../guardrails/guardrail_hooks/presidio.py | 10 +- litellm/proxy/hooks/dynamic_rate_limiter.py | 58 ++-- .../proxy/hooks/key_management_event_hooks.py | 1 - .../proxy/hooks/parallel_request_limiter.py | 32 +- .../proxy/hooks/prompt_injection_detection.py | 1 - .../proxy/hooks/proxy_track_cost_callback.py | 8 +- litellm/proxy/litellm_pre_call_utils.py | 16 +- .../budget_management_endpoints.py | 1 - .../management_endpoints/common_utils.py | 1 - .../customer_endpoints.py | 1 - .../internal_user_endpoints.py | 41 ++- .../key_management_endpoints.py | 41 ++- .../model_management_endpoints.py | 3 - .../organization_endpoints.py | 43 +-- .../management_endpoints/team_endpoints.py | 28 +- litellm/proxy/management_endpoints/ui_sso.py | 12 +- litellm/proxy/management_helpers/utils.py | 3 +- .../openai_files_endpoints/files_endpoints.py | 1 - .../anthropic_passthrough_logging_handler.py | 7 +- .../vertex_passthrough_logging_handler.py | 5 +- .../pass_through_endpoints.py | 3 - .../passthrough_endpoint_router.py | 6 +- .../streaming_handler.py | 1 - litellm/proxy/prisma_migration.py | 30 +- litellm/proxy/proxy_server.py | 31 +- .../spend_management_endpoints.py | 6 - litellm/proxy/utils.py | 7 +- .../vertex_ai_endpoints/langfuse_endpoints.py | 8 +- litellm/rerank_api/main.py | 18 +- litellm/responses/main.py | 27 +- litellm/router.py | 92 ++--- .../router_strategy/base_routing_strategy.py | 6 +- litellm/router_strategy/budget_limiter.py | 41 +-- litellm/router_strategy/lowest_tpm_rpm_v2.py | 1 - litellm/router_utils/cooldown_callbacks.py | 6 +- .../router_utils/pattern_match_deployments.py | 12 +- .../secret_managers/aws_secret_manager_v2.py | 1 - litellm/types/integrations/arize_phoenix.py | 4 +- litellm/types/llms/openai.py | 36 +- litellm/types/rerank.py | 7 +- litellm/types/router.py | 29 +- litellm/types/utils.py | 3 - 
litellm/utils.py | 53 ++- 192 files changed, 1260 insertions(+), 1361 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fb37f3252498..4818ca6ca0fc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,10 +14,10 @@ repos: types: [python] files: litellm/.*\.py exclude: ^litellm/__init__.py$ -- repo: https://github.com/psf/black - rev: 24.2.0 - hooks: - - id: black +# - repo: https://github.com/psf/black +# rev: 24.2.0 +# hooks: +# - id: black - repo: https://github.com/pycqa/flake8 rev: 7.0.0 # The version of flake8 to use hooks: diff --git a/enterprise/enterprise_hooks/secret_detection.py b/enterprise/enterprise_hooks/secret_detection.py index 459fd374d1dd..158f26efa30e 100644 --- a/enterprise/enterprise_hooks/secret_detection.py +++ b/enterprise/enterprise_hooks/secret_detection.py @@ -444,9 +444,7 @@ def scan_message_for_secrets(self, message_content: str): detected_secrets = [] for file in secrets.files: - for found_secret in secrets[file]: - if found_secret.secret_value is None: continue detected_secrets.append( @@ -471,14 +469,12 @@ async def async_pre_call_hook( data: dict, call_type: str, # "completion", "embeddings", "image_generation", "moderation" ): - if await self.should_run_check(user_api_key_dict) is False: return if "messages" in data and isinstance(data["messages"], list): for message in data["messages"]: if "content" in message and isinstance(message["content"], str): - detected_secrets = self.scan_message_for_secrets(message["content"]) for secret in detected_secrets: diff --git a/litellm/__init__.py b/litellm/__init__.py index a4903f828cf7..c2e366e2b13e 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -122,19 +122,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False argilla_batch_size: Optional[int] = None datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload -gcs_pub_sub_use_v1: Optional[bool] = ( - False # if you want to use v1 gcs pubsub logged payload -) +gcs_pub_sub_use_v1: Optional[ + bool +] = False # if you want to use v1 gcs pubsub logged payload argilla_transformation_object: Optional[Dict[str, Any]] = None -_async_input_callback: List[Union[str, Callable, CustomLogger]] = ( - [] -) # internal variable - async custom callbacks are routed here. -_async_success_callback: List[Union[str, Callable, CustomLogger]] = ( - [] -) # internal variable - async custom callbacks are routed here. -_async_failure_callback: List[Union[str, Callable, CustomLogger]] = ( - [] -) # internal variable - async custom callbacks are routed here. +_async_input_callback: List[ + Union[str, Callable, CustomLogger] +] = [] # internal variable - async custom callbacks are routed here. +_async_success_callback: List[ + Union[str, Callable, CustomLogger] +] = [] # internal variable - async custom callbacks are routed here. +_async_failure_callback: List[ + Union[str, Callable, CustomLogger] +] = [] # internal variable - async custom callbacks are routed here. 
pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] turn_off_message_logging: Optional[bool] = False @@ -142,18 +142,18 @@ redact_messages_in_exceptions: Optional[bool] = False redact_user_api_key_info: Optional[bool] = False filter_invalid_headers: Optional[bool] = False -add_user_information_to_llm_headers: Optional[bool] = ( - None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers -) +add_user_information_to_llm_headers: Optional[ + bool +] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers store_audit_logs = False # Enterprise feature, allow users to see audit logs ### end of callbacks ############# -email: Optional[str] = ( - None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -) -token: Optional[str] = ( - None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -) +email: Optional[ + str +] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +token: Optional[ + str +] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 telemetry = True max_tokens = 256 # OpenAI Defaults drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False)) @@ -229,24 +229,20 @@ enable_caching_on_provider_specific_optional_params: bool = ( False # feature-flag for caching on optional params - e.g. 'top_k' ) -caching: bool = ( - False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -) -caching_with_models: bool = ( - False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -) -cache: Optional[Cache] = ( - None # cache object <- use this - https://docs.litellm.ai/docs/caching -) +caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +cache: Optional[ + Cache +] = None # cache object <- use this - https://docs.litellm.ai/docs/caching default_in_memory_ttl: Optional[float] = None default_redis_ttl: Optional[float] = None default_redis_batch_cache_expiry: Optional[float] = None model_alias_map: Dict[str, str] = {} model_group_alias_map: Dict[str, str] = {} max_budget: float = 0.0 # set the max budget across all providers -budget_duration: Optional[str] = ( - None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). -) +budget_duration: Optional[ + str +] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). 
default_soft_budget: float = ( 50.0 # by default all litellm proxy keys have a soft budget of 50.0 ) @@ -255,15 +251,11 @@ _current_cost = 0.0 # private variable, used if max budget is set error_logs: Dict = {} -add_function_to_prompt: bool = ( - False # if function calling not supported by api, append function call details to system prompt -) +add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt client_session: Optional[httpx.Client] = None aclient_session: Optional[httpx.AsyncClient] = None model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' -model_cost_map_url: str = ( - "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" -) +model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" suppress_debug_info = False dynamodb_table_name: Optional[str] = None s3_callback_params: Optional[Dict] = None @@ -285,9 +277,7 @@ custom_prometheus_metadata_labels: List[str] = [] #### REQUEST PRIORITIZATION #### priority_reservation: Optional[Dict[str, float]] = None -force_ipv4: bool = ( - False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6. -) +force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6. module_level_aclient = AsyncHTTPHandler( timeout=request_timeout, client_alias="module level aclient" ) @@ -301,13 +291,13 @@ context_window_fallbacks: Optional[List] = None content_policy_fallbacks: Optional[List] = None allowed_fails: int = 3 -num_retries_per_request: Optional[int] = ( - None # for the request overall (incl. fallbacks + model retries) -) +num_retries_per_request: Optional[ + int +] = None # for the request overall (incl. fallbacks + model retries) ####### SECRET MANAGERS ##################### -secret_manager_client: Optional[Any] = ( - None # list of instantiated key management clients - e.g. azure kv, infisical, etc. -) +secret_manager_client: Optional[ + Any +] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. _google_kms_resource_name: Optional[str] = None _key_management_system: Optional[KeyManagementSystem] = None _key_management_settings: KeyManagementSettings = KeyManagementSettings() @@ -1056,10 +1046,10 @@ def add_known_models(): from .types.utils import GenericStreamingChunk custom_provider_map: List[CustomLLMItem] = [] -_custom_providers: List[str] = ( - [] -) # internal helper util, used to track names of custom providers -disable_hf_tokenizer_download: Optional[bool] = ( - None # disable huggingface tokenizer download. Defaults to openai clk100 -) +_custom_providers: List[ + str +] = [] # internal helper util, used to track names of custom providers +disable_hf_tokenizer_download: Optional[ + bool +] = None # disable huggingface tokenizer download. 
Defaults to openai clk100 global_disable_no_log_param: bool = False diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 1ddcafce4ce7..f4f74c72fb0c 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -153,7 +153,6 @@ def create_batch( ) api_base: Optional[str] = None if custom_llm_provider == "openai": - # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( optional_params.api_base @@ -358,7 +357,6 @@ def retrieve_batch( _is_async = kwargs.pop("aretrieve_batch", False) is True api_base: Optional[str] = None if custom_llm_provider == "openai": - # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( optional_params.api_base diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py index 09fabf1c1210..14278de9cd53 100644 --- a/litellm/caching/caching_handler.py +++ b/litellm/caching/caching_handler.py @@ -66,9 +66,7 @@ class CachingHandlerResponse(BaseModel): cached_result: Optional[Any] = None final_embedding_cached_response: Optional[EmbeddingResponse] = None - embedding_all_elements_cache_hit: bool = ( - False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call - ) + embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call class LLMCachingHandler: @@ -738,7 +736,6 @@ def sync_set_cache( if self._should_store_result_in_cache( original_function=self.original_function, kwargs=new_kwargs ): - litellm.cache.add_cache(result, **new_kwargs) return @@ -865,9 +862,9 @@ def _update_litellm_logging_obj_environment( } if litellm.cache is not None: - litellm_params["preset_cache_key"] = ( - litellm.cache._get_preset_cache_key_from_kwargs(**kwargs) - ) + litellm_params[ + "preset_cache_key" + ] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs) else: litellm_params["preset_cache_key"] = None diff --git a/litellm/caching/llm_caching_handler.py b/litellm/caching/llm_caching_handler.py index 429634b7b1f1..3bf1f80d0884 100644 --- a/litellm/caching/llm_caching_handler.py +++ b/litellm/caching/llm_caching_handler.py @@ -8,7 +8,6 @@ class LLMClientCache(InMemoryCache): - def update_cache_key_with_event_loop(self, key): """ Add the event loop to the cache key, to prevent event loop closed errors. 
diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index 7378ed878abb..29fa44715374 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -57,7 +57,6 @@ def __init__( socket_timeout: Optional[float] = 5.0, # default 5 second timeout **kwargs, ): - from litellm._service_logger import ServiceLogging from .._redis import get_redis_client, get_redis_connection_pool diff --git a/litellm/caching/redis_semantic_cache.py b/litellm/caching/redis_semantic_cache.py index f46bb661ef36..6ebd1c8f33a9 100644 --- a/litellm/caching/redis_semantic_cache.py +++ b/litellm/caching/redis_semantic_cache.py @@ -13,23 +13,27 @@ import asyncio import json import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, cast import litellm from litellm._logging import print_verbose -from litellm.litellm_core_utils.prompt_templates.common_utils import get_str_from_messages +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + get_str_from_messages, +) +from litellm.types.utils import EmbeddingResponse + from .base_cache import BaseCache class RedisSemanticCache(BaseCache): """ - Redis-backed semantic cache for LLM responses. - - This cache uses vector similarity to find semantically similar prompts that have been + Redis-backed semantic cache for LLM responses. + + This cache uses vector similarity to find semantically similar prompts that have been previously sent to the LLM, allowing for cache hits even when prompts are not identical but carry similar meaning. """ - + DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index" def __init__( @@ -57,7 +61,7 @@ def __init__( index_name: Name for the Redis index ttl: Default time-to-live for cache entries in seconds **kwargs: Additional arguments passed to the Redis client - + Raises: Exception: If similarity_threshold is not provided or required Redis connection information is missing @@ -69,14 +73,14 @@ def __init__( index_name = self.DEFAULT_REDIS_INDEX_NAME print_verbose(f"Redis semantic-cache initializing index - {index_name}") - + # Validate similarity threshold if similarity_threshold is None: raise ValueError("similarity_threshold must be provided, passed None") - + # Store configuration self.similarity_threshold = similarity_threshold - + # Convert similarity threshold [0,1] to distance threshold [0,2] # For cosine distance: 0 = most similar, 2 = least similar # While similarity: 1 = most similar, 0 = least similar @@ -87,14 +91,16 @@ def __init__( if redis_url is None: try: # Attempt to use provided parameters or fallback to environment variables - host = host or os.environ['REDIS_HOST'] - port = port or os.environ['REDIS_PORT'] - password = password or os.environ['REDIS_PASSWORD'] + host = host or os.environ["REDIS_HOST"] + port = port or os.environ["REDIS_PORT"] + password = password or os.environ["REDIS_PASSWORD"] except KeyError as e: # Raise a more informative exception if any of the required keys are missing missing_var = e.args[0] - raise ValueError(f"Missing required Redis configuration: {missing_var}. " - f"Provide {missing_var} or redis_url.") from e + raise ValueError( + f"Missing required Redis configuration: {missing_var}. " + f"Provide {missing_var} or redis_url." + ) from e redis_url = f"redis://:{password}@{host}:{port}" @@ -114,7 +120,7 @@ def __init__( def _get_ttl(self, **kwargs) -> Optional[int]: """ Get the TTL (time-to-live) value for cache entries. 
- + Args: **kwargs: Keyword arguments that may contain a custom TTL @@ -125,33 +131,33 @@ def _get_ttl(self, **kwargs) -> Optional[int]: if ttl is not None: ttl = int(ttl) return ttl - + def _get_embedding(self, prompt: str) -> List[float]: """ Generate an embedding vector for the given prompt using the configured embedding model. - + Args: prompt: The text to generate an embedding for - + Returns: List[float]: The embedding vector """ # Create an embedding from prompt - embedding_response = litellm.embedding( + embedding_response = cast(EmbeddingResponse, litellm.embedding( model=self.embedding_model, input=prompt, cache={"no-store": True, "no-cache": True}, - ) + )) embedding = embedding_response["data"][0]["embedding"] return embedding def _get_cache_logic(self, cached_response: Any) -> Any: """ Process the cached response to prepare it for use. - + Args: cached_response: The raw cached response - + Returns: The processed cache response, or None if input was None """ @@ -171,13 +177,13 @@ def _get_cache_logic(self, cached_response: Any) -> Any: except (ValueError, SyntaxError) as e: print_verbose(f"Error parsing cached response: {str(e)}") return None - + return cached_response def set_cache(self, key: str, value: Any, **kwargs) -> None: """ Store a value in the semantic cache. - + Args: key: The cache key (not directly used in semantic caching) value: The response value to cache @@ -186,13 +192,14 @@ def set_cache(self, key: str, value: Any, **kwargs) -> None: """ print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}") + value_str: Optional[str] = None try: # Extract the prompt from messages messages = kwargs.get("messages", []) if not messages: print_verbose("No messages provided for semantic caching") return - + prompt = get_str_from_messages(messages) value_str = str(value) @@ -203,16 +210,18 @@ def set_cache(self, key: str, value: Any, **kwargs) -> None: else: self.llmcache.store(prompt, value_str) except Exception as e: - print_verbose(f"Error setting {value_str} in the Redis semantic cache: {str(e)}") + print_verbose( + f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}" + ) def get_cache(self, key: str, **kwargs) -> Any: """ Retrieve a semantically similar cached response. 
- + Args: key: The cache key (not directly used in semantic caching) **kwargs: Additional arguments including 'messages' for the prompt - + Returns: The cached response if a semantically similar prompt is found, else None """ @@ -224,7 +233,7 @@ def get_cache(self, key: str, **kwargs) -> Any: if not messages: print_verbose("No messages provided for semantic cache lookup") return None - + prompt = get_str_from_messages(messages) # Check the cache for semantically similar prompts results = self.llmcache.check(prompt=prompt) @@ -236,12 +245,12 @@ def get_cache(self, key: str, **kwargs) -> Any: # Process the best matching result cache_hit = results[0] vector_distance = float(cache_hit["vector_distance"]) - + # Convert vector distance back to similarity score # For cosine distance: 0 = most similar, 2 = least similar # While similarity: 1 = most similar, 0 = least similar similarity = 1 - vector_distance - + cached_prompt = cache_hit["prompt"] cached_response = cache_hit["response"] @@ -251,19 +260,19 @@ def get_cache(self, key: str, **kwargs) -> Any: f"current prompt: {prompt}, " f"cached prompt: {cached_prompt}" ) - + return self._get_cache_logic(cached_response=cached_response) except Exception as e: print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}") - + async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]: """ Asynchronously generate an embedding for the given prompt. - + Args: prompt: The text to generate an embedding for **kwargs: Additional arguments that may contain metadata - + Returns: List[float]: The embedding vector """ @@ -275,7 +284,7 @@ async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]: if llm_model_list is not None else [] ) - + try: if llm_router is not None and self.embedding_model in router_model_names: # Use the router for embedding generation @@ -307,7 +316,7 @@ async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]: async def async_set_cache(self, key: str, value: Any, **kwargs) -> None: """ Asynchronously store a value in the semantic cache. - + Args: key: The cache key (not directly used in semantic caching) value: The response value to cache @@ -322,13 +331,13 @@ async def async_set_cache(self, key: str, value: Any, **kwargs) -> None: if not messages: print_verbose("No messages provided for semantic caching") return - + prompt = get_str_from_messages(messages) value_str = str(value) # Generate embedding for the value (response) to cache prompt_embedding = await self._get_async_embedding(prompt, **kwargs) - + # Get TTL and store in Redis semantic cache ttl = self._get_ttl(**kwargs) if ttl is not None: @@ -336,13 +345,13 @@ async def async_set_cache(self, key: str, value: Any, **kwargs) -> None: prompt, value_str, vector=prompt_embedding, # Pass through custom embedding - ttl=ttl + ttl=ttl, ) else: await self.llmcache.astore( prompt, value_str, - vector=prompt_embedding # Pass through custom embedding + vector=prompt_embedding, # Pass through custom embedding ) except Exception as e: print_verbose(f"Error in async_set_cache: {str(e)}") @@ -350,11 +359,11 @@ async def async_set_cache(self, key: str, value: Any, **kwargs) -> None: async def async_get_cache(self, key: str, **kwargs) -> Any: """ Asynchronously retrieve a semantically similar cached response. 
- + Args: key: The cache key (not directly used in semantic caching) **kwargs: Additional arguments including 'messages' for the prompt - + Returns: The cached response if a semantically similar prompt is found, else None """ @@ -367,21 +376,20 @@ async def async_get_cache(self, key: str, **kwargs) -> Any: print_verbose("No messages provided for semantic cache lookup") kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None - + prompt = get_str_from_messages(messages) - + # Generate embedding for the prompt prompt_embedding = await self._get_async_embedding(prompt, **kwargs) # Check the cache for semantically similar prompts - results = await self.llmcache.acheck( - prompt=prompt, - vector=prompt_embedding - ) + results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding) # handle results / cache hit if not results: - kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 # TODO why here but not above?? + kwargs.setdefault("metadata", {})[ + "semantic-similarity" + ] = 0.0 # TODO why here but not above?? return None cache_hit = results[0] @@ -404,7 +412,7 @@ async def async_get_cache(self, key: str, **kwargs) -> Any: f"current prompt: {prompt}, " f"cached prompt: {cached_prompt}" ) - + return self._get_cache_logic(cached_response=cached_response) except Exception as e: print_verbose(f"Error in async_get_cache: {str(e)}") @@ -413,17 +421,19 @@ async def async_get_cache(self, key: str, **kwargs) -> Any: async def _index_info(self) -> Dict[str, Any]: """ Get information about the Redis index. - + Returns: Dict[str, Any]: Information about the Redis index """ aindex = await self.llmcache._get_async_index() return await aindex.info() - async def async_set_cache_pipeline(self, cache_list: List[Tuple[str, Any]], **kwargs) -> None: + async def async_set_cache_pipeline( + self, cache_list: List[Tuple[str, Any]], **kwargs + ) -> None: """ Asynchronously store multiple values in the semantic cache. - + Args: cache_list: List of (key, value) tuples to cache **kwargs: Additional arguments diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index f34359144353..a41fc364ab4b 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -580,7 +580,6 @@ def completion_cost( # noqa: PLR0915 - For un-mapped Replicate models, the cost is calculated based on the total time used for the request. 
""" try: - call_type = _infer_call_type(call_type, completion_response) or "completion" if ( diff --git a/litellm/fine_tuning/main.py b/litellm/fine_tuning/main.py index b726a394c23b..09c070fffb14 100644 --- a/litellm/fine_tuning/main.py +++ b/litellm/fine_tuning/main.py @@ -138,7 +138,6 @@ def create_fine_tuning_job( # OpenAI if custom_llm_provider == "openai": - # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( optional_params.api_base @@ -360,7 +359,6 @@ def cancel_fine_tuning_job( # OpenAI if custom_llm_provider == "openai": - # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( optional_params.api_base @@ -522,7 +520,6 @@ def list_fine_tuning_jobs( # OpenAI if custom_llm_provider == "openai": - # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( optional_params.api_base diff --git a/litellm/integrations/SlackAlerting/batching_handler.py b/litellm/integrations/SlackAlerting/batching_handler.py index e35cf61d6300..fdce2e047938 100644 --- a/litellm/integrations/SlackAlerting/batching_handler.py +++ b/litellm/integrations/SlackAlerting/batching_handler.py @@ -19,7 +19,6 @@ def squash_payloads(queue): - squashed = {} if len(queue) == 0: return squashed diff --git a/litellm/integrations/SlackAlerting/slack_alerting.py b/litellm/integrations/SlackAlerting/slack_alerting.py index a2e626476030..50f0538cfd66 100644 --- a/litellm/integrations/SlackAlerting/slack_alerting.py +++ b/litellm/integrations/SlackAlerting/slack_alerting.py @@ -195,12 +195,15 @@ async def response_taking_too_long_callback( if self.alerting is None or self.alert_types is None: return - time_difference_float, model, api_base, messages = ( - self._response_taking_too_long_callback_helper( - kwargs=kwargs, - start_time=start_time, - end_time=end_time, - ) + ( + time_difference_float, + model, + api_base, + messages, + ) = self._response_taking_too_long_callback_helper( + kwargs=kwargs, + start_time=start_time, + end_time=end_time, ) if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions: messages = "Message not logged. 
litellm.redact_messages_in_exceptions=True" @@ -819,9 +822,9 @@ async def region_outage_alerts( ### UNIQUE CACHE KEY ### cache_key = provider + region_name - outage_value: Optional[ProviderRegionOutageModel] = ( - await self.internal_usage_cache.async_get_cache(key=cache_key) - ) + outage_value: Optional[ + ProviderRegionOutageModel + ] = await self.internal_usage_cache.async_get_cache(key=cache_key) if ( getattr(exception, "status_code", None) is None @@ -1402,9 +1405,9 @@ async def send_alert( self.alert_to_webhook_url is not None and alert_type in self.alert_to_webhook_url ): - slack_webhook_url: Optional[Union[str, List[str]]] = ( - self.alert_to_webhook_url[alert_type] - ) + slack_webhook_url: Optional[ + Union[str, List[str]] + ] = self.alert_to_webhook_url[alert_type] elif self.default_webhook_url is not None: slack_webhook_url = self.default_webhook_url else: @@ -1768,7 +1771,6 @@ async def send_virtual_key_event_slack( - Team Created, Updated, Deleted """ try: - message = f"`{event_name}`\n" key_event_dict = key_event.model_dump() diff --git a/litellm/integrations/_types/open_inference.py b/litellm/integrations/_types/open_inference.py index b5076c0e4256..bcfabe9b7b14 100644 --- a/litellm/integrations/_types/open_inference.py +++ b/litellm/integrations/_types/open_inference.py @@ -283,4 +283,4 @@ class OpenInferenceSpanKindValues(Enum): class OpenInferenceMimeTypeValues(Enum): TEXT = "text/plain" - JSON = "application/json" \ No newline at end of file + JSON = "application/json" diff --git a/litellm/integrations/argilla.py b/litellm/integrations/argilla.py index 055ad90259af..a362ce7e4d7b 100644 --- a/litellm/integrations/argilla.py +++ b/litellm/integrations/argilla.py @@ -98,7 +98,6 @@ def get_credentials_from_env( argilla_dataset_name: Optional[str], argilla_base_url: Optional[str], ) -> ArgillaCredentialsObject: - _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY") if _credentials_api_key is None: raise Exception("Invalid Argilla API Key given. 
_credentials_api_key=None.") diff --git a/litellm/integrations/arize/arize.py b/litellm/integrations/arize/arize.py index 7a0fb785a7b7..7aa412eef52e 100644 --- a/litellm/integrations/arize/arize.py +++ b/litellm/integrations/arize/arize.py @@ -26,7 +26,6 @@ class ArizeLogger(OpenTelemetry): - def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]): ArizeLogger.set_arize_attributes(span, kwargs, response_obj) return diff --git a/litellm/integrations/arize/arize_phoenix.py b/litellm/integrations/arize/arize_phoenix.py index d7b7d5812b2e..b2f77522241e 100644 --- a/litellm/integrations/arize/arize_phoenix.py +++ b/litellm/integrations/arize/arize_phoenix.py @@ -1,14 +1,17 @@ import os from typing import TYPE_CHECKING, Any -from litellm.integrations.arize import _utils + from litellm._logging import verbose_logger +from litellm.integrations.arize import _utils from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig if TYPE_CHECKING: - from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig - from litellm.types.integrations.arize import Protocol as _Protocol from opentelemetry.trace import Span as _Span + from litellm.types.integrations.arize import Protocol as _Protocol + + from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig + Protocol = _Protocol OpenTelemetryConfig = _OpenTelemetryConfig Span = _Span @@ -20,6 +23,7 @@ ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://app.phoenix.arize.com/v1/traces" + class ArizePhoenixLogger: @staticmethod def set_arize_phoenix_attributes(span: Span, kwargs, response_obj): @@ -49,7 +53,7 @@ def get_arize_phoenix_config() -> ArizePhoenixConfig: protocol = "otlp_grpc" else: endpoint = ARIZE_HOSTED_PHOENIX_ENDPOINT - protocol = "otlp_http" + protocol = "otlp_http" verbose_logger.debug( f"No PHOENIX_COLLECTOR_ENDPOINT or PHOENIX_COLLECTOR_HTTP_ENDPOINT found, using default endpoint with http: {ARIZE_HOSTED_PHOENIX_ENDPOINT}" ) @@ -57,17 +61,16 @@ def get_arize_phoenix_config() -> ArizePhoenixConfig: otlp_auth_headers = None # If the endpoint is the Arize hosted Phoenix endpoint, use the api_key as the auth header as currently it is uses # a slightly different auth header format than self hosted phoenix - if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT: + if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT: if api_key is None: - raise ValueError("PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used.") + raise ValueError( + "PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used." 
+ ) otlp_auth_headers = f"api_key={api_key}" elif api_key is not None: # api_key/auth is optional for self hosted phoenix otlp_auth_headers = f"Authorization=Bearer {api_key}" return ArizePhoenixConfig( - otlp_auth_headers=otlp_auth_headers, - protocol=protocol, - endpoint=endpoint + otlp_auth_headers=otlp_auth_headers, protocol=protocol, endpoint=endpoint ) - diff --git a/litellm/integrations/athina.py b/litellm/integrations/athina.py index 705dc11f1d3c..49b9e9e68721 100644 --- a/litellm/integrations/athina.py +++ b/litellm/integrations/athina.py @@ -12,7 +12,10 @@ def __init__(self): "athina-api-key": self.athina_api_key, "Content-Type": "application/json", } - self.athina_logging_url = os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + "/api/v1/log/inference" + self.athina_logging_url = ( + os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + + "/api/v1/log/inference" + ) self.additional_keys = [ "environment", "prompt_slug", diff --git a/litellm/integrations/azure_storage/azure_storage.py b/litellm/integrations/azure_storage/azure_storage.py index ddc46b117ffb..27f5e0e11230 100644 --- a/litellm/integrations/azure_storage/azure_storage.py +++ b/litellm/integrations/azure_storage/azure_storage.py @@ -50,12 +50,12 @@ def __init__( self.azure_storage_file_system: str = _azure_storage_file_system # Internal variables used for Token based authentication - self.azure_auth_token: Optional[str] = ( - None # the Azure AD token to use for Azure Storage API requests - ) - self.token_expiry: Optional[datetime] = ( - None # the expiry time of the currentAzure AD token - ) + self.azure_auth_token: Optional[ + str + ] = None # the Azure AD token to use for Azure Storage API requests + self.token_expiry: Optional[ + datetime + ] = None # the expiry time of the currentAzure AD token asyncio.create_task(self.periodic_flush()) self.flush_lock = asyncio.Lock() @@ -153,7 +153,6 @@ async def async_upload_payload_to_azure_blob_storage( 3. 
Flush the data """ try: - if self.azure_storage_account_key: await self.upload_to_azure_data_lake_with_azure_account_key( payload=payload diff --git a/litellm/integrations/braintrust_logging.py b/litellm/integrations/braintrust_logging.py index 281fbda01e36..0961eab02b8b 100644 --- a/litellm/integrations/braintrust_logging.py +++ b/litellm/integrations/braintrust_logging.py @@ -4,7 +4,7 @@ import copy import os from datetime import datetime -from typing import Optional, Dict +from typing import Dict, Optional import httpx from pydantic import BaseModel @@ -19,7 +19,9 @@ ) from litellm.utils import print_verbose -global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback) +global_braintrust_http_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback +) global_braintrust_sync_http_handler = HTTPHandler() API_BASE = "https://api.braintrustdata.com/v1" @@ -35,7 +37,9 @@ def get_utc_datetime(): class BraintrustLogger(CustomLogger): - def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None: + def __init__( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> None: super().__init__() self.validate_environment(api_key=api_key) self.api_base = api_base or API_BASE @@ -45,7 +49,9 @@ def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None "Authorization": "Bearer " + self.api_key, "Content-Type": "application/json", } - self._project_id_cache: Dict[str, str] = {} # Cache mapping project names to IDs + self._project_id_cache: Dict[ + str, str + ] = {} # Cache mapping project names to IDs def validate_environment(self, api_key: Optional[str]): """ @@ -71,7 +77,9 @@ def get_project_id_sync(self, project_name: str) -> str: try: response = global_braintrust_sync_http_handler.post( - f"{self.api_base}/project", headers=self.headers, json={"name": project_name} + f"{self.api_base}/project", + headers=self.headers, + json={"name": project_name}, ) project_dict = response.json() project_id = project_dict["id"] @@ -89,7 +97,9 @@ async def get_project_id_async(self, project_name: str) -> str: try: response = await global_braintrust_http_handler.post( - f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name} + f"{self.api_base}/project/register", + headers=self.headers, + json={"name": project_name}, ) project_dict = response.json() project_id = project_dict["id"] @@ -116,15 +126,21 @@ def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict: if metadata is None: metadata = {} - proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {} + proxy_headers = ( + litellm_params.get("proxy_server_request", {}).get("headers", {}) or {} + ) for metadata_param_key in proxy_headers: if metadata_param_key.startswith("braintrust"): trace_param_key = metadata_param_key.replace("braintrust", "", 1) if trace_param_key in metadata: - verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header") + verbose_logger.warning( + f"Overwriting Braintrust `{trace_param_key}` from request header" + ) else: - verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header") + verbose_logger.debug( + f"Found Braintrust `{trace_param_key}` in request header" + ) metadata[trace_param_key] = proxy_headers.get(metadata_param_key) return metadata @@ -157,24 +173,35 @@ def log_success_event( # noqa: PLR0915 output = None choices = [] if response_obj is not None and ( - 
kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse) + kwargs.get("call_type", None) == "embedding" + or isinstance(response_obj, litellm.EmbeddingResponse) ): output = None - elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.ModelResponse + ): output = response_obj["choices"][0]["message"].json() choices = response_obj["choices"] - elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.TextCompletionResponse + ): output = response_obj.choices[0].text choices = response_obj.choices - elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.ImageResponse + ): output = response_obj["data"] litellm_params = kwargs.get("litellm_params", {}) - metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None + metadata = ( + litellm_params.get("metadata", {}) or {} + ) # if litellm_params['metadata'] == None metadata = self.add_metadata_from_header(litellm_params, metadata) clean_metadata = {} try: - metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata + metadata = copy.deepcopy( + metadata + ) # Avoid modifying the original metadata except Exception: new_metadata = {} for key, value in metadata.items(): @@ -192,7 +219,9 @@ def log_success_event( # noqa: PLR0915 project_id = metadata.get("project_id") if project_id is None: project_name = metadata.get("project_name") - project_id = self.get_project_id_sync(project_name) if project_name else None + project_id = ( + self.get_project_id_sync(project_name) if project_name else None + ) if project_id is None: if self.default_project_id is None: @@ -234,7 +263,8 @@ def log_success_event( # noqa: PLR0915 "completion_tokens": usage_obj.completion_tokens, "total_tokens": usage_obj.total_tokens, "total_cost": cost, - "time_to_first_token": end_time.timestamp() - start_time.timestamp(), + "time_to_first_token": end_time.timestamp() + - start_time.timestamp(), "start": start_time.timestamp(), "end": end_time.timestamp(), } @@ -255,7 +285,9 @@ def log_success_event( # noqa: PLR0915 request_data["metrics"] = metrics try: - print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}") + print_verbose( + f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}" + ) global_braintrust_sync_http_handler.post( url=f"{self.api_base}/project_logs/{project_id}/insert", json={"events": [request_data]}, @@ -276,20 +308,29 @@ async def async_log_success_event( # noqa: PLR0915 output = None choices = [] if response_obj is not None and ( - kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse) + kwargs.get("call_type", None) == "embedding" + or isinstance(response_obj, litellm.EmbeddingResponse) ): output = None - elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.ModelResponse + ): output = response_obj["choices"][0]["message"].json() choices = response_obj["choices"] - elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.TextCompletionResponse + 
): output = response_obj.choices[0].text choices = response_obj.choices - elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.ImageResponse + ): output = response_obj["data"] litellm_params = kwargs.get("litellm_params", {}) - metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None + metadata = ( + litellm_params.get("metadata", {}) or {} + ) # if litellm_params['metadata'] == None metadata = self.add_metadata_from_header(litellm_params, metadata) clean_metadata = {} new_metadata = {} @@ -313,7 +354,11 @@ async def async_log_success_event( # noqa: PLR0915 project_id = metadata.get("project_id") if project_id is None: project_name = metadata.get("project_name") - project_id = await self.get_project_id_async(project_name) if project_name else None + project_id = ( + await self.get_project_id_async(project_name) + if project_name + else None + ) if project_id is None: if self.default_project_id is None: @@ -362,8 +407,14 @@ async def async_log_success_event( # noqa: PLR0915 api_call_start_time = kwargs.get("api_call_start_time") completion_start_time = kwargs.get("completion_start_time") - if api_call_start_time is not None and completion_start_time is not None: - metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp() + if ( + api_call_start_time is not None + and completion_start_time is not None + ): + metrics["time_to_first_token"] = ( + completion_start_time.timestamp() + - api_call_start_time.timestamp() + ) request_data = { "id": litellm_call_id, diff --git a/litellm/integrations/custom_batch_logger.py b/litellm/integrations/custom_batch_logger.py index 3cfdf82caba4..f9d4496c21f3 100644 --- a/litellm/integrations/custom_batch_logger.py +++ b/litellm/integrations/custom_batch_logger.py @@ -14,7 +14,6 @@ class CustomBatchLogger(CustomLogger): - def __init__( self, flush_lock: Optional[asyncio.Lock] = None, diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py index 4421664bfcd7..41a3800116e3 100644 --- a/litellm/integrations/custom_guardrail.py +++ b/litellm/integrations/custom_guardrail.py @@ -7,7 +7,6 @@ class CustomGuardrail(CustomLogger): - def __init__( self, guardrail_name: Optional[str] = None, diff --git a/litellm/integrations/datadog/datadog.py b/litellm/integrations/datadog/datadog.py index 4f4b05c84e8f..e9b6b6b1647a 100644 --- a/litellm/integrations/datadog/datadog.py +++ b/litellm/integrations/datadog/datadog.py @@ -233,7 +233,6 @@ def log_success_event(self, kwargs, response_obj, start_time, end_time): pass async def _log_async_event(self, kwargs, response_obj, start_time, end_time): - dd_payload = self.create_datadog_logging_payload( kwargs=kwargs, response_obj=response_obj, diff --git a/litellm/integrations/gcs_bucket/gcs_bucket_base.py b/litellm/integrations/gcs_bucket/gcs_bucket_base.py index 66995d8482f3..0ce845ecb2d8 100644 --- a/litellm/integrations/gcs_bucket/gcs_bucket_base.py +++ b/litellm/integrations/gcs_bucket/gcs_bucket_base.py @@ -125,9 +125,9 @@ async def get_gcs_logging_config( if kwargs is None: kwargs = {} - standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( - kwargs.get("standard_callback_dynamic_params", None) - ) + standard_callback_dynamic_params: Optional[ + StandardCallbackDynamicParams + ] = kwargs.get("standard_callback_dynamic_params", None) bucket_name: str path_service_account: Optional[str] diff 
--git a/litellm/integrations/gcs_pubsub/pub_sub.py b/litellm/integrations/gcs_pubsub/pub_sub.py index 1b078df7bc83..bdaedcd90895 100644 --- a/litellm/integrations/gcs_pubsub/pub_sub.py +++ b/litellm/integrations/gcs_pubsub/pub_sub.py @@ -70,12 +70,13 @@ async def construct_request_headers(self) -> Dict[str, str]: """Construct authorization headers using Vertex AI auth""" from litellm import vertex_chat_completion - _auth_header, vertex_project = ( - await vertex_chat_completion._ensure_access_token_async( - credentials=self.path_service_account_json, - project_id=None, - custom_llm_provider="vertex_ai", - ) + ( + _auth_header, + vertex_project, + ) = await vertex_chat_completion._ensure_access_token_async( + credentials=self.path_service_account_json, + project_id=None, + custom_llm_provider="vertex_ai", ) auth_header, _ = vertex_chat_completion._get_token_and_url( diff --git a/litellm/integrations/humanloop.py b/litellm/integrations/humanloop.py index fd3463f9e332..4651238af456 100644 --- a/litellm/integrations/humanloop.py +++ b/litellm/integrations/humanloop.py @@ -155,11 +155,7 @@ def get_chat_completion_prompt( prompt_id: str, prompt_variables: Optional[dict], dynamic_callback_params: StandardCallbackDynamicParams, - ) -> Tuple[ - str, - List[AllMessageValues], - dict, - ]: + ) -> Tuple[str, List[AllMessageValues], dict,]: humanloop_api_key = dynamic_callback_params.get( "humanloop_api_key" ) or get_secret_str("HUMANLOOP_API_KEY") diff --git a/litellm/integrations/langfuse/langfuse.py b/litellm/integrations/langfuse/langfuse.py index f990a316c4ce..d0472ee6383e 100644 --- a/litellm/integrations/langfuse/langfuse.py +++ b/litellm/integrations/langfuse/langfuse.py @@ -471,9 +471,9 @@ def _log_langfuse_v2( # noqa: PLR0915 # we clean out all extra litellm metadata params before logging clean_metadata: Dict[str, Any] = {} if prompt_management_metadata is not None: - clean_metadata["prompt_management_metadata"] = ( - prompt_management_metadata - ) + clean_metadata[ + "prompt_management_metadata" + ] = prompt_management_metadata if isinstance(metadata, dict): for key, value in metadata.items(): # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy diff --git a/litellm/integrations/langfuse/langfuse_handler.py b/litellm/integrations/langfuse/langfuse_handler.py index aebe1461b019..f9d27f6cf00d 100644 --- a/litellm/integrations/langfuse/langfuse_handler.py +++ b/litellm/integrations/langfuse/langfuse_handler.py @@ -19,7 +19,6 @@ class LangFuseHandler: - @staticmethod def get_langfuse_logger_for_request( standard_callback_dynamic_params: StandardCallbackDynamicParams, @@ -87,7 +86,9 @@ def _return_global_langfuse_logger( if globalLangfuseLogger is not None: return globalLangfuseLogger - credentials_dict: Dict[str, Any] = ( + credentials_dict: Dict[ + str, Any + ] = ( {} ) # the global langfuse logger uses Environment Variables, there are no dynamic credentials globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache( diff --git a/litellm/integrations/langfuse/langfuse_prompt_management.py b/litellm/integrations/langfuse/langfuse_prompt_management.py index 1f4ca84db346..30f991ebd6db 100644 --- a/litellm/integrations/langfuse/langfuse_prompt_management.py +++ b/litellm/integrations/langfuse/langfuse_prompt_management.py @@ -172,11 +172,7 @@ async def async_get_chat_completion_prompt( prompt_id: str, prompt_variables: Optional[dict], dynamic_callback_params: StandardCallbackDynamicParams, - ) -> Tuple[ - str, - List[AllMessageValues], - dict, - ]: + ) -> Tuple[str, 
List[AllMessageValues], dict,]: return self.get_chat_completion_prompt( model, messages, diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py index 1ef90c182201..0914150db944 100644 --- a/litellm/integrations/langsmith.py +++ b/litellm/integrations/langsmith.py @@ -75,7 +75,6 @@ def get_credentials_from_env( langsmith_project: Optional[str] = None, langsmith_base_url: Optional[str] = None, ) -> LangsmithCredentialsObject: - _credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY") if _credentials_api_key is None: raise Exception( @@ -443,9 +442,9 @@ def _get_credentials_to_use_for_request( Otherwise, use the default credentials. """ - standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( - kwargs.get("standard_callback_dynamic_params", None) - ) + standard_callback_dynamic_params: Optional[ + StandardCallbackDynamicParams + ] = kwargs.get("standard_callback_dynamic_params", None) if standard_callback_dynamic_params is not None: credentials = self.get_credentials_from_env( langsmith_api_key=standard_callback_dynamic_params.get( @@ -481,7 +480,6 @@ def _send_batch(self): asyncio.run(self.async_send_batch()) def get_run_by_id(self, run_id): - langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"] langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"] diff --git a/litellm/integrations/lunary.py b/litellm/integrations/lunary.py index fcd781e44e5e..b24a24e08812 100644 --- a/litellm/integrations/lunary.py +++ b/litellm/integrations/lunary.py @@ -20,7 +20,6 @@ def parse_tool_calls(tool_calls): return None def clean_tool_call(tool_call): - serialized = { "type": tool_call.type, "id": tool_call.id, @@ -36,7 +35,6 @@ def clean_tool_call(tool_call): def parse_messages(input): - if input is None: return None diff --git a/litellm/integrations/mlflow.py b/litellm/integrations/mlflow.py index 193d1c4ea2cd..e7a458accf9a 100644 --- a/litellm/integrations/mlflow.py +++ b/litellm/integrations/mlflow.py @@ -48,14 +48,17 @@ def _handle_success(self, kwargs, response_obj, start_time, end_time): def _extract_and_set_chat_attributes(self, span, kwargs, response_obj): try: - from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools + from mlflow.tracing.utils import set_span_chat_messages # type: ignore + from mlflow.tracing.utils import set_span_chat_tools # type: ignore except ImportError: return inputs = self._construct_input(kwargs) input_messages = inputs.get("messages", []) - output_messages = [c.message.model_dump(exclude_none=True) - for c in getattr(response_obj, "choices", [])] + output_messages = [ + c.message.model_dump(exclude_none=True) + for c in getattr(response_obj, "choices", []) + ] if messages := [*input_messages, *output_messages]: set_span_chat_messages(span, messages) if tools := inputs.get("tools"): diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 1572eb81f5b1..177b2ae02bd7 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -46,7 +46,6 @@ @dataclass class OpenTelemetryConfig: - exporter: Union[str, SpanExporter] = "console" endpoint: Optional[str] = None headers: Optional[str] = None @@ -154,7 +153,6 @@ async def async_service_success_hook( end_time: Optional[Union[datetime, float]] = None, event_metadata: Optional[dict] = None, ): - from opentelemetry import trace from opentelemetry.trace import Status, StatusCode @@ -215,7 +213,6 @@ async def async_service_failure_hook( end_time: 
Optional[Union[float, datetime]] = None, event_metadata: Optional[dict] = None, ): - from opentelemetry import trace from opentelemetry.trace import Status, StatusCode @@ -353,9 +350,9 @@ def _add_dynamic_span_processor_if_needed(self, kwargs): """ from opentelemetry import trace - standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( - kwargs.get("standard_callback_dynamic_params") - ) + standard_callback_dynamic_params: Optional[ + StandardCallbackDynamicParams + ] = kwargs.get("standard_callback_dynamic_params") if not standard_callback_dynamic_params: return @@ -722,7 +719,6 @@ def safe_set_attribute(self, span: Span, key: str, value: Any): span.set_attribute(key, primitive_value) def set_raw_request_attributes(self, span: Span, kwargs, response_obj): - kwargs.get("optional_params", {}) litellm_params = kwargs.get("litellm_params", {}) or {} custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown") @@ -907,7 +903,6 @@ async def async_management_endpoint_success_hook( logging_payload: ManagementEndpointLoggingPayload, parent_otel_span: Optional[Span] = None, ): - from opentelemetry import trace from opentelemetry.trace import Status, StatusCode @@ -961,7 +956,6 @@ async def async_management_endpoint_failure_hook( logging_payload: ManagementEndpointLoggingPayload, parent_otel_span: Optional[Span] = None, ): - from opentelemetry import trace from opentelemetry.trace import Status, StatusCode diff --git a/litellm/integrations/opik/opik.py b/litellm/integrations/opik/opik.py index 1f7f18f336b2..8cbfb9e6535e 100644 --- a/litellm/integrations/opik/opik.py +++ b/litellm/integrations/opik/opik.py @@ -185,7 +185,6 @@ async def async_send_batch(self): def _create_opik_payload( # noqa: PLR0915 self, kwargs, response_obj, start_time, end_time ) -> List[Dict]: - # Get metadata _litellm_params = kwargs.get("litellm_params", {}) or {} litellm_params_metadata = _litellm_params.get("metadata", {}) or {} diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index d6e47b87ced5..5ac8c80eb30c 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -988,9 +988,9 @@ def set_llm_deployment_success_metrics( ): try: verbose_logger.debug("setting remaining tokens requests metric") - standard_logging_payload: Optional[StandardLoggingPayload] = ( - request_kwargs.get("standard_logging_object") - ) + standard_logging_payload: Optional[ + StandardLoggingPayload + ] = request_kwargs.get("standard_logging_object") if standard_logging_payload is None: return diff --git a/litellm/integrations/prompt_management_base.py b/litellm/integrations/prompt_management_base.py index 3fe3b31ed8bb..07b6720ffeda 100644 --- a/litellm/integrations/prompt_management_base.py +++ b/litellm/integrations/prompt_management_base.py @@ -14,7 +14,6 @@ class PromptManagementClient(TypedDict): class PromptManagementBase(ABC): - @property @abstractmethod def integration_name(self) -> str: @@ -83,11 +82,7 @@ def get_chat_completion_prompt( prompt_id: str, prompt_variables: Optional[dict], dynamic_callback_params: StandardCallbackDynamicParams, - ) -> Tuple[ - str, - List[AllMessageValues], - dict, - ]: + ) -> Tuple[str, List[AllMessageValues], dict,]: if not self.should_run_prompt_management( prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params ): diff --git a/litellm/integrations/weights_biases.py b/litellm/integrations/weights_biases.py index 5fcbab04b3ec..63d87c9bd908 100644 --- a/litellm/integrations/weights_biases.py +++ 
b/litellm/integrations/weights_biases.py @@ -21,11 +21,11 @@ class OpenAIResponse(Protocol[K, V]): # type: ignore # contains a (known) object attribute object: Literal["chat.completion", "edit", "text_completion"] - def __getitem__(self, key: K) -> V: ... # noqa + def __getitem__(self, key: K) -> V: + ... # noqa - def get( # noqa - self, key: K, default: Optional[V] = None - ) -> Optional[V]: ... # pragma: no cover + def get(self, key: K, default: Optional[V] = None) -> Optional[V]: # noqa + ... # pragma: no cover class OpenAIRequestResponseResolver: def __call__( diff --git a/litellm/litellm_core_utils/default_encoding.py b/litellm/litellm_core_utils/default_encoding.py index 05bf78a6a9cb..93b3132912cb 100644 --- a/litellm/litellm_core_utils/default_encoding.py +++ b/litellm/litellm_core_utils/default_encoding.py @@ -11,7 +11,9 @@ # Old way to access resources, which setuptools deprecated some time ago import pkg_resources # type: ignore - filename = pkg_resources.resource_filename(__name__, "litellm_core_utils/tokenizers") + filename = pkg_resources.resource_filename( + __name__, "litellm_core_utils/tokenizers" + ) os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv( "CUSTOM_TIKTOKEN_CACHE_DIR", filename diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 3565c4468c55..dcd3ae3a6427 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -239,9 +239,9 @@ def __init__( self.litellm_trace_id = litellm_trace_id self.function_id = function_id self.streaming_chunks: List[Any] = [] # for generating complete stream response - self.sync_streaming_chunks: List[Any] = ( - [] - ) # for generating complete stream response + self.sync_streaming_chunks: List[ + Any + ] = [] # for generating complete stream response self.log_raw_request_response = log_raw_request_response # Initialize dynamic callbacks @@ -452,18 +452,19 @@ def get_chat_completion_prompt( prompt_id: str, prompt_variables: Optional[dict], ) -> Tuple[str, List[AllMessageValues], dict]: - custom_logger = self.get_custom_logger_for_prompt_management(model) if custom_logger: - model, messages, non_default_params = ( - custom_logger.get_chat_completion_prompt( - model=model, - messages=messages, - non_default_params=non_default_params, - prompt_id=prompt_id, - prompt_variables=prompt_variables, - dynamic_callback_params=self.standard_callback_dynamic_params, - ) + ( + model, + messages, + non_default_params, + ) = custom_logger.get_chat_completion_prompt( + model=model, + messages=messages, + non_default_params=non_default_params, + prompt_id=prompt_id, + prompt_variables=prompt_variables, + dynamic_callback_params=self.standard_callback_dynamic_params, ) self.messages = messages return model, messages, non_default_params @@ -541,12 +542,11 @@ def _pre_call(self, input, api_key, model=None, additional_args={}): model ): # if model name was changes pre-call, overwrite the initial model call name with the new one self.model_call_details["model"] = model - self.model_call_details["litellm_params"]["api_base"] = ( - self._get_masked_api_base(additional_args.get("api_base", "")) - ) + self.model_call_details["litellm_params"][ + "api_base" + ] = self._get_masked_api_base(additional_args.get("api_base", "")) def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 - # Log the exact input to the LLM API litellm.error_logs["PRE_CALL"] = locals() try: @@ -568,19 +568,16 @@ def pre_call(self, input, api_key, 
model=None, additional_args={}): # noqa: PLR self.log_raw_request_response is True or log_raw_request_response is True ): - _litellm_params = self.model_call_details.get("litellm_params", {}) _metadata = _litellm_params.get("metadata", {}) or {} try: # [Non-blocking Extra Debug Information in metadata] if turn_off_message_logging is True: - - _metadata["raw_request"] = ( - "redacted by litellm. \ + _metadata[ + "raw_request" + ] = "redacted by litellm. \ 'litellm.turn_off_message_logging=True'" - ) else: - curl_command = self._get_request_curl_command( api_base=additional_args.get("api_base", ""), headers=additional_args.get("headers", {}), @@ -590,33 +587,33 @@ def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR _metadata["raw_request"] = str(curl_command) # split up, so it's easier to parse in the UI - self.model_call_details["raw_request_typed_dict"] = ( - RawRequestTypedDict( - raw_request_api_base=str( - additional_args.get("api_base") or "" - ), - raw_request_body=self._get_raw_request_body( - additional_args.get("complete_input_dict", {}) - ), - raw_request_headers=self._get_masked_headers( - additional_args.get("headers", {}) or {}, - ignore_sensitive_headers=True, - ), - error=None, - ) + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + raw_request_api_base=str( + additional_args.get("api_base") or "" + ), + raw_request_body=self._get_raw_request_body( + additional_args.get("complete_input_dict", {}) + ), + raw_request_headers=self._get_masked_headers( + additional_args.get("headers", {}) or {}, + ignore_sensitive_headers=True, + ), + error=None, ) except Exception as e: - self.model_call_details["raw_request_typed_dict"] = ( - RawRequestTypedDict( - error=str(e), - ) + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + error=str(e), ) traceback.print_exc() - _metadata["raw_request"] = ( - "Unable to Log \ + _metadata[ + "raw_request" + ] = "Unable to Log \ raw request: {}".format( - str(e) - ) + str(e) ) if self.logger_fn and callable(self.logger_fn): try: @@ -941,9 +938,9 @@ def _response_cost_calculator( verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details["response_cost_failure_debug_information"] = ( - debug_info - ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info return None try: @@ -968,9 +965,9 @@ def _response_cost_calculator( verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details["response_cost_failure_debug_information"] = ( - debug_info - ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info return None @@ -995,7 +992,6 @@ async def _response_cost_calculator_async( def should_run_callback( self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str ) -> bool: - if litellm.global_disable_no_log_param: return True @@ -1027,9 +1023,9 @@ def _success_handler_helper_fn( end_time = datetime.datetime.now() if self.completion_start_time is None: self.completion_start_time = end_time - self.model_call_details["completion_start_time"] = ( - self.completion_start_time - ) + self.model_call_details[ + "completion_start_time" + ] = self.completion_start_time self.model_call_details["log_event_type"] = "successful_api_call" self.model_call_details["end_time"] = end_time self.model_call_details["cache_hit"] = cache_hit @@ -1083,39 +1079,39 @@ def _success_handler_helper_fn( "response_cost" ] else: - 
self.model_call_details["response_cost"] = ( - self._response_cost_calculator(result=result) - ) + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator(result=result) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) elif isinstance(result, dict): # pass-through endpoints ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) elif standard_logging_object is not None: - self.model_call_details["standard_logging_object"] = ( - standard_logging_object - ) + self.model_call_details[ + "standard_logging_object" + ] = standard_logging_object else: # streaming chunks + image gen. 
self.model_call_details["response_cost"] = None @@ -1154,7 +1150,6 @@ def success_handler( # noqa: PLR0915 standard_logging_object=kwargs.get("standard_logging_object", None), ) try: - ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response: Optional[ Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse] @@ -1172,23 +1167,23 @@ def success_handler( # noqa: PLR0915 verbose_logger.debug( "Logging Details LiteLLM-Success Call streaming complete" ) - self.model_call_details["complete_streaming_response"] = ( - complete_streaming_response - ) - self.model_call_details["response_cost"] = ( - self._response_cost_calculator(result=complete_streaming_response) - ) + self.model_call_details[ + "complete_streaming_response" + ] = complete_streaming_response + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator(result=complete_streaming_response) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_success_callbacks, @@ -1207,7 +1202,6 @@ def success_handler( # noqa: PLR0915 ## LOGGING HOOK ## for callback in callbacks: if isinstance(callback, CustomLogger): - self.model_call_details, result = callback.logging_hook( kwargs=self.model_call_details, result=result, @@ -1538,10 +1532,10 @@ def success_handler( # noqa: PLR0915 ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = ( - self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} ) result = self.model_call_details["complete_response"] openMeterLogger.log_success_event( @@ -1581,10 +1575,10 @@ def success_handler( # noqa: PLR0915 ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = ( - self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} ) result = self.model_call_details["complete_response"] @@ -1659,7 +1653,6 @@ async def async_success_handler( # noqa: PLR0915 if self.call_type == CallTypes.aretrieve_batch.value and isinstance( result, LiteLLMBatch ): - response_cost, batch_usage, batch_models = await _handle_completed_batch( batch=result, custom_llm_provider=self.custom_llm_provider ) @@ -1692,9 +1685,9 @@ async def async_success_handler( # noqa: PLR0915 if complete_streaming_response is not None: print_verbose("Async success callbacks: Got a complete streaming response") - self.model_call_details["async_complete_streaming_response"] = ( - complete_streaming_response - ) + self.model_call_details[ + "async_complete_streaming_response" + ] = complete_streaming_response try: if 
self.model_call_details.get("cache_hit", False) is True: self.model_call_details["response_cost"] = 0.0 @@ -1704,10 +1697,10 @@ async def async_success_handler( # noqa: PLR0915 model_call_details=self.model_call_details ) # base_model defaults to None if not set on model_info - self.model_call_details["response_cost"] = ( - self._response_cost_calculator( - result=complete_streaming_response - ) + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator( + result=complete_streaming_response ) verbose_logger.debug( @@ -1720,16 +1713,16 @@ async def async_success_handler( # noqa: PLR0915 self.model_call_details["response_cost"] = None ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_async_success_callbacks, @@ -1935,18 +1928,18 @@ def _failure_handler_helper_fn( ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj={}, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="failure", - error_str=str(exception), - original_exception=exception, - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj={}, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="failure", + error_str=str(exception), + original_exception=exception, + standard_built_in_tools_params=self.standard_built_in_tools_params, ) return start_time, end_time @@ -2084,7 +2077,6 @@ def failure_handler( # noqa: PLR0915 ) is not True ): # custom logger class - callback.log_failure_event( start_time=start_time, end_time=end_time, @@ -2713,9 +2705,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 endpoint=arize_config.endpoint, ) - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - f"space_key={arize_config.space_key},api_key={arize_config.api_key}" - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}" for callback in _in_memory_loggers: if ( isinstance(callback, ArizeLogger) @@ -2739,9 +2731,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 # auth can be disabled on local deployments of arize phoenix if arize_phoenix_config.otlp_auth_headers is not None: - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - arize_phoenix_config.otlp_auth_headers - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = arize_phoenix_config.otlp_auth_headers for callback in _in_memory_loggers: if ( @@ -2832,9 +2824,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 exporter="otlp_http", endpoint="https://langtrace.ai/api/trace", ) - 
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - f"api_key={os.getenv('LANGTRACE_API_KEY')}" - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}" for callback in _in_memory_loggers: if ( isinstance(callback, OpenTelemetry) @@ -3223,7 +3215,6 @@ def get_model_cost_information( custom_llm_provider: Optional[str], init_response_obj: Union[Any, BaseModel, dict], ) -> StandardLoggingModelInformation: - model_cost_name = _select_model_name_for_cost_calc( model=None, completion_response=init_response_obj, # type: ignore @@ -3286,7 +3277,6 @@ def get_final_response_obj( def get_additional_headers( additiona_headers: Optional[dict], ) -> Optional[StandardLoggingAdditionalHeaders]: - if additiona_headers is None: return None @@ -3322,10 +3312,10 @@ def get_hidden_params( for key in StandardLoggingHiddenParams.__annotations__.keys(): if key in hidden_params: if key == "additional_headers": - clean_hidden_params["additional_headers"] = ( - StandardLoggingPayloadSetup.get_additional_headers( - hidden_params[key] - ) + clean_hidden_params[ + "additional_headers" + ] = StandardLoggingPayloadSetup.get_additional_headers( + hidden_params[key] ) else: clean_hidden_params[key] = hidden_params[key] # type: ignore @@ -3463,12 +3453,14 @@ def get_standard_logging_object_payload( ) # cleanup timestamps - start_time_float, end_time_float, completion_start_time_float = ( - StandardLoggingPayloadSetup.cleanup_timestamps( - start_time=start_time, - end_time=end_time, - completion_start_time=completion_start_time, - ) + ( + start_time_float, + end_time_float, + completion_start_time_float, + ) = StandardLoggingPayloadSetup.cleanup_timestamps( + start_time=start_time, + end_time=end_time, + completion_start_time=completion_start_time, ) response_time = StandardLoggingPayloadSetup.get_response_time( start_time_float=start_time_float, @@ -3495,7 +3487,6 @@ def get_standard_logging_object_payload( saved_cache_cost: float = 0.0 if cache_hit is True: - id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id saved_cache_cost = ( logging_obj._response_cost_calculator( @@ -3658,9 +3649,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): ): for k, v in metadata["user_api_key_metadata"].items(): if k == "logging": # prevent logging user logging keys - cleaned_user_api_key_metadata[k] = ( - "scrubbed_by_litellm_for_sensitive_keys" - ) + cleaned_user_api_key_metadata[ + k + ] = "scrubbed_by_litellm_for_sensitive_keys" else: cleaned_user_api_key_metadata[k] = v diff --git a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py index d33af2a47726..f3f4ce6ef4db 100644 --- a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py +++ b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py @@ -258,14 +258,12 @@ def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[s class LiteLLMResponseObjectHandler: - @staticmethod def convert_to_image_response( response_object: dict, model_response_object: Optional[ImageResponse] = None, hidden_params: Optional[dict] = None, ) -> ImageResponse: - response_object.update({"hidden_params": hidden_params}) if model_response_object is None: @@ -481,9 +479,9 @@ def convert_to_model_response_object( # noqa: PLR0915 provider_specific_fields["thinking_blocks"] = thinking_blocks if reasoning_content: - provider_specific_fields["reasoning_content"] = ( - 
reasoning_content - ) + provider_specific_fields[ + "reasoning_content" + ] = reasoning_content message = Message( content=content, diff --git a/litellm/litellm_core_utils/model_param_helper.py b/litellm/litellm_core_utils/model_param_helper.py index b7d8fc19d140..c96b4a3f5b85 100644 --- a/litellm/litellm_core_utils/model_param_helper.py +++ b/litellm/litellm_core_utils/model_param_helper.py @@ -17,7 +17,6 @@ class ModelParamHelper: - @staticmethod def get_standard_logging_model_parameters( model_parameters: dict, diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index c8745f51192b..4170d3c1e16b 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -257,7 +257,6 @@ def _insert_assistant_continue_message( and message.get("role") == "user" # Current is user and messages[i + 1].get("role") == "user" ): # Next is user - # Insert assistant message continue_message = ( assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index 28e09d7ac8af..3c89141b5ee0 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -1042,10 +1042,10 @@ def convert_to_gemini_tool_call_invoke( if tool_calls is not None: for tool in tool_calls: if "function" in tool: - gemini_function_call: Optional[VertexFunctionCall] = ( - _gemini_tool_call_invoke_helper( - function_call_params=tool["function"] - ) + gemini_function_call: Optional[ + VertexFunctionCall + ] = _gemini_tool_call_invoke_helper( + function_call_params=tool["function"] ) if gemini_function_call is not None: _parts_list.append( @@ -1432,9 +1432,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_content_element["cache_control"] = ( - _content_element["cache_control"] - ) + _anthropic_content_element[ + "cache_control" + ] = _content_element["cache_control"] user_content.append(_anthropic_content_element) elif m.get("type", "") == "text": m = cast(ChatCompletionTextObject, m) @@ -1466,9 +1466,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_content_text_element["cache_control"] = ( - _content_element["cache_control"] - ) + _anthropic_content_text_element[ + "cache_control" + ] = _content_element["cache_control"] user_content.append(_anthropic_content_text_element) @@ -1533,7 +1533,6 @@ def anthropic_messages_pt( # noqa: PLR0915 "content" ] # don't pass empty text blocks. anthropic api raises errors. 
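The cache_control hunks above copy an optional cache_control entry from the OpenAI-style content element onto the corresponding Anthropic content block when it is present. A minimal sketch of that propagation, using plain dicts in place of LiteLLM's typed dicts:

# hand-rolled stand-in for the OpenAI-style content element
_content_element = {
    "type": "text",
    "text": "hello",
    "cache_control": {"type": "ephemeral"},
}

# build the Anthropic-style block, then carry cache_control over if set
_anthropic_content_element = {"type": "text", "text": _content_element["text"]}
if "cache_control" in _content_element:
    _anthropic_content_element["cache_control"] = _content_element["cache_control"]

print(_anthropic_content_element)
# {'type': 'text', 'text': 'hello', 'cache_control': {'type': 'ephemeral'}}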
): - _anthropic_text_content_element = AnthropicMessagesTextParam( type="text", text=assistant_content_block["content"], @@ -1569,7 +1568,6 @@ def anthropic_messages_pt( # noqa: PLR0915 msg_i += 1 if assistant_content: - new_messages.append({"role": "assistant", "content": assistant_content}) if msg_i == init_msg_i: # prevent infinite loops @@ -2245,7 +2243,6 @@ def _post_call_image_processing(response: httpx.Response) -> Tuple[str, str]: @staticmethod async def get_image_details_async(image_url) -> Tuple[str, str]: try: - client = get_async_httpx_client( llm_provider=httpxSpecialProvider.PromptFactory, params={"concurrent_limit": 1}, @@ -2612,7 +2609,6 @@ def get_user_message_block_or_continue_message( for item in modified_content_block: # Check if the list is empty if item["type"] == "text": - if not item["text"].strip(): # Replace empty text with continue message _user_continue_message = ChatCompletionUserMessage( @@ -3207,7 +3203,6 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915 assistant_content: List[BedrockContentBlock] = [] ## MERGE CONSECUTIVE ASSISTANT CONTENT ## while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": - assistant_message_block = get_assistant_message_block_or_continue_message( message=messages[msg_i], assistant_continue_message=assistant_continue_message, @@ -3410,7 +3405,6 @@ def response_schema_prompt(model: str, response_schema: dict) -> str: {"role": "user", "content": "{}".format(response_schema)} ] if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict: - custom_prompt_details = litellm.custom_prompt_dict[ f"{model}/response_schema_prompt" ] # allow user to define custom response schema prompt by model diff --git a/litellm/litellm_core_utils/realtime_streaming.py b/litellm/litellm_core_utils/realtime_streaming.py index aebd0496922c..e84c72044185 100644 --- a/litellm/litellm_core_utils/realtime_streaming.py +++ b/litellm/litellm_core_utils/realtime_streaming.py @@ -122,7 +122,6 @@ async def client_ack_messages(self): pass async def bidirectional_forward(self): - forward_task = asyncio.create_task(self.backend_to_client_send_messages()) try: await self.client_ack_messages() diff --git a/litellm/litellm_core_utils/redact_messages.py b/litellm/litellm_core_utils/redact_messages.py index 50e0e0b5755a..a62031a9c9bb 100644 --- a/litellm/litellm_core_utils/redact_messages.py +++ b/litellm/litellm_core_utils/redact_messages.py @@ -135,9 +135,9 @@ def _get_turn_off_message_logging_from_dynamic_params( handles boolean and string values of `turn_off_message_logging` """ - standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( - model_call_details.get("standard_callback_dynamic_params", None) - ) + standard_callback_dynamic_params: Optional[ + StandardCallbackDynamicParams + ] = model_call_details.get("standard_callback_dynamic_params", None) if standard_callback_dynamic_params: _turn_off_message_logging = standard_callback_dynamic_params.get( "turn_off_message_logging" diff --git a/litellm/litellm_core_utils/sensitive_data_masker.py b/litellm/litellm_core_utils/sensitive_data_masker.py index 7800e5304fbb..23b9ec32fc70 100644 --- a/litellm/litellm_core_utils/sensitive_data_masker.py +++ b/litellm/litellm_core_utils/sensitive_data_masker.py @@ -1,4 +1,5 @@ from typing import Any, Dict, Optional, Set + from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH @@ -40,7 +41,10 @@ def is_sensitive_key(self, key: str) -> bool: return result def mask_dict( - self, data: Dict[str, Any], depth: int = 0, max_depth: int = 
DEFAULT_MAX_RECURSE_DEPTH + self, + data: Dict[str, Any], + depth: int = 0, + max_depth: int = DEFAULT_MAX_RECURSE_DEPTH, ) -> Dict[str, Any]: if depth >= max_depth: return data diff --git a/litellm/litellm_core_utils/streaming_chunk_builder_utils.py b/litellm/litellm_core_utils/streaming_chunk_builder_utils.py index 7a5ee3e41e3e..1ca2bfe45e99 100644 --- a/litellm/litellm_core_utils/streaming_chunk_builder_utils.py +++ b/litellm/litellm_core_utils/streaming_chunk_builder_utils.py @@ -104,7 +104,6 @@ def build_base_response(self, chunks: List[Dict[str, Any]]) -> ModelResponse: def get_combined_tool_content( self, tool_call_chunks: List[Dict[str, Any]] ) -> List[ChatCompletionMessageToolCall]: - argument_list: List[str] = [] delta = tool_call_chunks[0]["choices"][0]["delta"] id = None diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index a11e5af12b27..42106135ccc3 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -84,9 +84,9 @@ def __init__( self.system_fingerprint: Optional[str] = None self.received_finish_reason: Optional[str] = None - self.intermittent_finish_reason: Optional[str] = ( - None # finish reasons that show up mid-stream - ) + self.intermittent_finish_reason: Optional[ + str + ] = None # finish reasons that show up mid-stream self.special_tokens = [ "<|assistant|>", "<|system|>", @@ -814,7 +814,6 @@ def return_processed_chunk_logic( # noqa model_response: ModelResponseStream, response_obj: Dict[str, Any], ): - print_verbose( f"completion_obj: {completion_obj}, model_response.choices[0]: {model_response.choices[0]}, response_obj: {response_obj}" ) @@ -1008,7 +1007,6 @@ def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915 self.custom_llm_provider and self.custom_llm_provider in litellm._custom_providers ): - if self.received_finish_reason is not None: if "provider_specific_fields" not in chunk: raise StopIteration @@ -1379,9 +1377,9 @@ def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915 _json_delta = delta.model_dump() print_verbose(f"_json_delta: {_json_delta}") if "role" not in _json_delta or _json_delta["role"] is None: - _json_delta["role"] = ( - "assistant" # mistral's api returns role as None - ) + _json_delta[ + "role" + ] = "assistant" # mistral's api returns role as None if "tool_calls" in _json_delta and isinstance( _json_delta["tool_calls"], list ): @@ -1758,9 +1756,9 @@ async def __anext__(self): # noqa: PLR0915 chunk = next(self.completion_stream) if chunk is not None and chunk != b"": print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") - processed_chunk: Optional[ModelResponseStream] = ( - self.chunk_creator(chunk=chunk) - ) + processed_chunk: Optional[ + ModelResponseStream + ] = self.chunk_creator(chunk=chunk) print_verbose( f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}" ) diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index f2c5f390d7ab..7625292e6ee9 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -290,7 +290,6 @@ def completion( headers={}, client=None, ): - optional_params = copy.deepcopy(optional_params) stream = optional_params.pop("stream", None) json_mode: bool = optional_params.pop("json_mode", False) @@ -491,7 +490,6 @@ def check_empty_tool_call_args(self) -> bool: def _handle_usage( self, anthropic_usage_chunk: Union[dict, UsageDelta] ) -> AnthropicChatCompletionUsageBlock: - 
usage_block = AnthropicChatCompletionUsageBlock( prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0), completion_tokens=anthropic_usage_chunk.get("output_tokens", 0), @@ -515,7 +513,9 @@ def _handle_usage( return usage_block - def _content_block_delta_helper(self, chunk: dict) -> Tuple[ + def _content_block_delta_helper( + self, chunk: dict + ) -> Tuple[ str, Optional[ChatCompletionToolCallChunk], List[ChatCompletionThinkingBlock], @@ -592,9 +592,12 @@ def chunk_parser(self, chunk: dict) -> ModelResponseStream: Anthropic content chunk chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}} """ - text, tool_use, thinking_blocks, provider_specific_fields = ( - self._content_block_delta_helper(chunk=chunk) - ) + ( + text, + tool_use, + thinking_blocks, + provider_specific_fields, + ) = self._content_block_delta_helper(chunk=chunk) if thinking_blocks: reasoning_content = self._handle_reasoning_content( thinking_blocks=thinking_blocks @@ -620,7 +623,6 @@ def chunk_parser(self, chunk: dict) -> ModelResponseStream: "index": self.tool_index, } elif type_chunk == "content_block_stop": - ContentBlockStop(**chunk) # type: ignore # check if tool call content block is_empty = self.check_empty_tool_call_args() diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index dcbc6775dcc3..a8f36cdcad8a 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -49,9 +49,9 @@ class AnthropicConfig(BaseConfig): to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} """ - max_tokens: Optional[int] = ( - 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default) - ) + max_tokens: Optional[ + int + ] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default) stop_sequences: Optional[list] = None temperature: Optional[int] = None top_p: Optional[int] = None @@ -104,7 +104,6 @@ def get_supported_openai_params(self, model: str): def get_json_schema_from_pydantic_object( self, response_format: Union[Any, Dict, None] ) -> Optional[dict]: - return type_to_response_format_param( response_format, ref_template="/$defs/{model}" ) # Relevant issue: https://github.com/BerriAI/litellm/issues/7755 @@ -125,7 +124,6 @@ def get_anthropic_headers( is_vertex_request: bool = False, user_anthropic_beta_headers: Optional[List[str]] = None, ) -> dict: - betas = set() if prompt_caching_set: betas.add("prompt-caching-2024-07-31") @@ -300,7 +298,6 @@ def map_openai_params( model: str, drop_params: bool, ) -> dict: - is_thinking_enabled = self.is_thinking_enabled( non_default_params=non_default_params ) @@ -321,11 +318,11 @@ def map_openai_params( optional_params=optional_params, tools=tool_value ) if param == "tool_choice" or param == "parallel_tool_calls": - _tool_choice: Optional[AnthropicMessagesToolChoice] = ( - self._map_tool_choice( - tool_choice=non_default_params.get("tool_choice"), - parallel_tool_use=non_default_params.get("parallel_tool_calls"), - ) + _tool_choice: Optional[ + AnthropicMessagesToolChoice + ] = self._map_tool_choice( + tool_choice=non_default_params.get("tool_choice"), + parallel_tool_use=non_default_params.get("parallel_tool_calls"), ) if _tool_choice is not None: @@ -341,7 +338,6 @@ def map_openai_params( if param == "top_p": optional_params["top_p"] = value if param == "response_format" and isinstance(value, dict): - ignore_response_format_types = ["text"] if 
value["type"] in ignore_response_format_types: # value is a no-op continue @@ -470,9 +466,9 @@ def translate_system_message( text=system_message_block["content"], ) if "cache_control" in system_message_block: - anthropic_system_message_content["cache_control"] = ( - system_message_block["cache_control"] - ) + anthropic_system_message_content[ + "cache_control" + ] = system_message_block["cache_control"] anthropic_system_message_list.append( anthropic_system_message_content ) @@ -486,9 +482,9 @@ def translate_system_message( ) ) if "cache_control" in _content: - anthropic_system_message_content["cache_control"] = ( - _content["cache_control"] - ) + anthropic_system_message_content[ + "cache_control" + ] = _content["cache_control"] anthropic_system_message_list.append( anthropic_system_message_content @@ -597,7 +593,9 @@ def _transform_response_for_json_mode( ) return _message - def extract_response_content(self, completion_response: dict) -> Tuple[ + def extract_response_content( + self, completion_response: dict + ) -> Tuple[ str, Optional[List[Any]], Optional[List[ChatCompletionThinkingBlock]], @@ -693,9 +691,13 @@ def transform_response( reasoning_content: Optional[str] = None tool_calls: List[ChatCompletionToolCallChunk] = [] - text_content, citations, thinking_blocks, reasoning_content, tool_calls = ( - self.extract_response_content(completion_response=completion_response) - ) + ( + text_content, + citations, + thinking_blocks, + reasoning_content, + tool_calls, + ) = self.extract_response_content(completion_response=completion_response) _message = litellm.Message( tool_calls=tool_calls, diff --git a/litellm/llms/anthropic/completion/transformation.py b/litellm/llms/anthropic/completion/transformation.py index 7a260b6f9492..5cbc0b5fd8b8 100644 --- a/litellm/llms/anthropic/completion/transformation.py +++ b/litellm/llms/anthropic/completion/transformation.py @@ -54,9 +54,9 @@ class AnthropicTextConfig(BaseConfig): to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} """ - max_tokens_to_sample: Optional[int] = ( - litellm.max_tokens - ) # anthropic requires a default + max_tokens_to_sample: Optional[ + int + ] = litellm.max_tokens # anthropic requires a default stop_sequences: Optional[list] = None temperature: Optional[int] = None top_p: Optional[int] = None diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py index a7dfff74d9fc..099a2acdae8c 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py @@ -25,7 +25,6 @@ class AnthropicMessagesHandler: - @staticmethod async def _handle_anthropic_streaming( response: httpx.Response, @@ -74,19 +73,22 @@ async def anthropic_messages( """ # Use provided client or create a new one optional_params = GenericLiteLLMParams(**kwargs) - model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = ( - litellm.get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=optional_params.api_base, - api_key=optional_params.api_key, - ) + ( + model, + _custom_llm_provider, + dynamic_api_key, + dynamic_api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=optional_params.api_base, + api_key=optional_params.api_key, ) - anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = ( - 
ProviderConfigManager.get_provider_anthropic_messages_config( - model=model, - provider=litellm.LlmProviders(_custom_llm_provider), - ) + anthropic_messages_provider_config: Optional[ + BaseAnthropicMessagesConfig + ] = ProviderConfigManager.get_provider_anthropic_messages_config( + model=model, + provider=litellm.LlmProviders(_custom_llm_provider), ) if anthropic_messages_provider_config is None: raise ValueError( diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 03c5cc09ebe5..aed813fdab8c 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -654,7 +654,6 @@ async def aembedding( ) -> EmbeddingResponse: response = None try: - openai_aclient = self.get_azure_openai_client( api_version=api_version, api_base=api_base, @@ -835,7 +834,6 @@ async def make_async_azure_httpx_request( "2023-10-01-preview", ] ): # CREATE + POLL for azure dall-e-2 calls - api_base = modify_url( original_url=api_base, new_path="/openai/images/generations:submit" ) @@ -867,7 +865,6 @@ async def make_async_azure_httpx_request( ) while response.json()["status"] not in ["succeeded", "failed"]: if time.time() - start_time > timeout_secs: - raise AzureOpenAIError( status_code=408, message="Operation polling timed out." ) @@ -935,7 +932,6 @@ def make_sync_azure_httpx_request( "2023-10-01-preview", ] ): # CREATE + POLL for azure dall-e-2 calls - api_base = modify_url( original_url=api_base, new_path="/openai/images/generations:submit" ) @@ -1199,7 +1195,6 @@ def audio_speech( client=None, litellm_params: Optional[dict] = None, ) -> HttpxBinaryResponseContent: - max_retries = optional_params.pop("max_retries", 2) if aspeech is not None and aspeech is True: @@ -1253,7 +1248,6 @@ async def async_audio_speech( client=None, litellm_params: Optional[dict] = None, ) -> HttpxBinaryResponseContent: - azure_client: AsyncAzureOpenAI = self.get_azure_openai_client( api_base=api_base, api_version=api_version, diff --git a/litellm/llms/azure/batches/handler.py b/litellm/llms/azure/batches/handler.py index 1b93c526d5a8..7fc6388ba874 100644 --- a/litellm/llms/azure/batches/handler.py +++ b/litellm/llms/azure/batches/handler.py @@ -50,15 +50,15 @@ def create_batch( client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, litellm_params: Optional[dict] = None, ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]: - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - litellm_params=litellm_params or {}, - ) + azure_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( @@ -96,15 +96,15 @@ def retrieve_batch( client: Optional[AzureOpenAI] = None, litellm_params: Optional[dict] = None, ): - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - litellm_params=litellm_params or {}, - ) + azure_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, ) if azure_client is None: 
raise ValueError( @@ -144,15 +144,15 @@ def cancel_batch( client: Optional[AzureOpenAI] = None, litellm_params: Optional[dict] = None, ): - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - litellm_params=litellm_params or {}, - ) + azure_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( @@ -183,15 +183,15 @@ def list_batches( client: Optional[AzureOpenAI] = None, litellm_params: Optional[dict] = None, ): - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - litellm_params=litellm_params or {}, - ) + azure_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index 71092c8b993a..5d61557c210e 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -306,7 +306,6 @@ def initialize_azure_sdk_client( api_version: Optional[str], is_async: bool, ) -> dict: - azure_ad_token_provider: Optional[Callable[[], str]] = None # If we have api_key, then we have higher priority azure_ad_token = litellm_params.get("azure_ad_token") diff --git a/litellm/llms/azure/files/handler.py b/litellm/llms/azure/files/handler.py index d45ac9a315de..98407d05d5f0 100644 --- a/litellm/llms/azure/files/handler.py +++ b/litellm/llms/azure/files/handler.py @@ -46,16 +46,15 @@ def create_file( client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, litellm_params: Optional[dict] = None, ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]: - - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - litellm_params=litellm_params or {}, - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + openai_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, ) if openai_client is None: raise ValueError( @@ -95,15 +94,15 @@ def file_content( ) -> Union[ HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent] ]: - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - litellm_params=litellm_params or {}, - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + openai_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, ) if openai_client is None: raise ValueError( @@ -145,15 +144,15 @@ def retrieve_file( client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, 
litellm_params: Optional[dict] = None, ): - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - litellm_params=litellm_params or {}, - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + openai_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, ) if openai_client is None: raise ValueError( @@ -197,15 +196,15 @@ def delete_file( client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, litellm_params: Optional[dict] = None, ): - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - litellm_params=litellm_params or {}, - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + openai_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, ) if openai_client is None: raise ValueError( @@ -251,15 +250,15 @@ def list_files( client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, litellm_params: Optional[dict] = None, ): - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - litellm_params=litellm_params or {}, - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + openai_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, ) if openai_client is None: raise ValueError( diff --git a/litellm/llms/azure/fine_tuning/handler.py b/litellm/llms/azure/fine_tuning/handler.py index 3d7cc336fb50..429b83498965 100644 --- a/litellm/llms/azure/fine_tuning/handler.py +++ b/litellm/llms/azure/fine_tuning/handler.py @@ -25,14 +25,7 @@ def get_openai_client( _is_async: bool = False, api_version: Optional[str] = None, litellm_params: Optional[dict] = None, - ) -> Optional[ - Union[ - OpenAI, - AsyncOpenAI, - AzureOpenAI, - AsyncAzureOpenAI, - ] - ]: + ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]: # Override to use Azure-specific client initialization if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI): client = None diff --git a/litellm/llms/azure_ai/chat/transformation.py b/litellm/llms/azure_ai/chat/transformation.py index 154f345537e8..a1fd24efa198 100644 --- a/litellm/llms/azure_ai/chat/transformation.py +++ b/litellm/llms/azure_ai/chat/transformation.py @@ -145,7 +145,6 @@ def _transform_messages( 2. 
If message contains an image or audio, send as is (user-intended) """ for message in messages: - # Do nothing if the message contains an image or audio if _audio_or_image_in_message_content(message): continue diff --git a/litellm/llms/azure_ai/embed/cohere_transformation.py b/litellm/llms/azure_ai/embed/cohere_transformation.py index 38b0dbbe2340..64433c21b61c 100644 --- a/litellm/llms/azure_ai/embed/cohere_transformation.py +++ b/litellm/llms/azure_ai/embed/cohere_transformation.py @@ -22,7 +22,6 @@ def __init__(self) -> None: pass def _map_azure_model_group(self, model: str) -> str: - if model == "offer-cohere-embed-multili-paygo": return "Cohere-embed-v3-multilingual" elif model == "offer-cohere-embed-english-paygo": diff --git a/litellm/llms/azure_ai/embed/handler.py b/litellm/llms/azure_ai/embed/handler.py index f33c979ca299..da39c5f3b893 100644 --- a/litellm/llms/azure_ai/embed/handler.py +++ b/litellm/llms/azure_ai/embed/handler.py @@ -17,7 +17,6 @@ class AzureAIEmbedding(OpenAIChatCompletion): - def _process_response( self, image_embedding_responses: Optional[List], @@ -145,7 +144,6 @@ async def async_embedding( api_base: Optional[str] = None, client=None, ) -> EmbeddingResponse: - ( image_embeddings_request, v1_embeddings_request, diff --git a/litellm/llms/azure_ai/rerank/transformation.py b/litellm/llms/azure_ai/rerank/transformation.py index 842511f30dfa..4465e0d70a29 100644 --- a/litellm/llms/azure_ai/rerank/transformation.py +++ b/litellm/llms/azure_ai/rerank/transformation.py @@ -17,6 +17,7 @@ class AzureAIRerankConfig(CohereRerankConfig): """ Azure AI Rerank - Follows the same Spec as Cohere Rerank """ + def get_complete_url(self, api_base: Optional[str], model: str) -> str: if api_base is None: raise ValueError( diff --git a/litellm/llms/base.py b/litellm/llms/base.py index deced222ca8a..abc314bba05e 100644 --- a/litellm/llms/base.py +++ b/litellm/llms/base.py @@ -9,7 +9,6 @@ class BaseLLM: - _client_session: Optional[httpx.Client] = None def process_response( diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py index 45ea06b9e468..b4b120776cca 100644 --- a/litellm/llms/base_llm/chat/transformation.py +++ b/litellm/llms/base_llm/chat/transformation.py @@ -218,7 +218,6 @@ def _add_response_format_to_tools( json_schema = value["json_schema"]["schema"] if json_schema and not is_response_format_supported: - _tool_choice = ChatCompletionToolChoiceObjectParam( type="function", function=ChatCompletionToolChoiceFunctionParam( diff --git a/litellm/llms/base_llm/responses/transformation.py b/litellm/llms/base_llm/responses/transformation.py index 29555c55dad6..e98a579845f2 100644 --- a/litellm/llms/base_llm/responses/transformation.py +++ b/litellm/llms/base_llm/responses/transformation.py @@ -58,7 +58,6 @@ def map_openai_params( model: str, drop_params: bool, ) -> Dict: - pass @abstractmethod diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py index a4230177b542..7f529c637a87 100644 --- a/litellm/llms/bedrock/chat/converse_handler.py +++ b/litellm/llms/bedrock/chat/converse_handler.py @@ -81,7 +81,6 @@ def make_sync_call( class BedrockConverseLLM(BaseAWSLLM): - def __init__(self) -> None: super().__init__() @@ -114,7 +113,6 @@ async def async_streaming( fake_stream: bool = False, json_mode: Optional[bool] = False, ) -> CustomStreamWrapper: - request_data = await litellm.AmazonConverseConfig()._async_transform_request( model=model, messages=messages, @@ -179,7 +177,6 @@ async 
def async_completion( headers: dict = {}, client: Optional[AsyncHTTPHandler] = None, ) -> Union[ModelResponse, CustomStreamWrapper]: - request_data = await litellm.AmazonConverseConfig()._async_transform_request( model=model, messages=messages, @@ -265,7 +262,6 @@ def completion( # noqa: PLR0915 extra_headers: Optional[dict] = None, client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, ): - ## SETUP ## stream = optional_params.pop("stream", None) unencoded_model_id = optional_params.pop("model_id", None) @@ -301,9 +297,9 @@ def completion( # noqa: PLR0915 aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) optional_params.pop("aws_region_name", None) - litellm_params["aws_region_name"] = ( - aws_region_name # [DO NOT DELETE] important for async calls - ) + litellm_params[ + "aws_region_name" + ] = aws_region_name # [DO NOT DELETE] important for async calls credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index ced9c469b329..05386c62b5c6 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -223,7 +223,6 @@ def map_openai_params( ) for param, value in non_default_params.items(): if param == "response_format" and isinstance(value, dict): - ignore_response_format_types = ["text"] if value["type"] in ignore_response_format_types: # value is a no-op continue @@ -715,9 +714,9 @@ def _transform_response( chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} content_str = "" tools: List[ChatCompletionToolCallChunk] = [] - reasoningContentBlocks: Optional[List[BedrockConverseReasoningContentBlock]] = ( - None - ) + reasoningContentBlocks: Optional[ + List[BedrockConverseReasoningContentBlock] + ] = None if message is not None: for idx, content in enumerate(message["content"]): @@ -727,7 +726,6 @@ def _transform_response( if "text" in content: content_str += content["text"] if "toolUse" in content: - ## check tool name was formatted by litellm _response_tool_name = content["toolUse"]["name"] response_tool_name = get_bedrock_tool_name( @@ -754,12 +752,12 @@ def _transform_response( chat_completion_message["provider_specific_fields"] = { "reasoningContentBlocks": reasoningContentBlocks, } - chat_completion_message["reasoning_content"] = ( - self._transform_reasoning_content(reasoningContentBlocks) - ) - chat_completion_message["thinking_blocks"] = ( - self._transform_thinking_blocks(reasoningContentBlocks) - ) + chat_completion_message[ + "reasoning_content" + ] = self._transform_reasoning_content(reasoningContentBlocks) + chat_completion_message[ + "thinking_blocks" + ] = self._transform_thinking_blocks(reasoningContentBlocks) chat_completion_message["content"] = content_str if json_mode is True and tools is not None and len(tools) == 1: # to support 'json_schema' logic on bedrock models diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py index 5b02fd3158df..09bdd6357221 100644 --- a/litellm/llms/bedrock/chat/invoke_handler.py +++ b/litellm/llms/bedrock/chat/invoke_handler.py @@ -496,9 +496,9 @@ def process_response( # noqa: PLR0915 content=None, ) model_response.choices[0].message = _message # type: ignore - model_response._hidden_params["original_response"] = ( - outputText # allow user to access raw anthropic tool calling response - ) + model_response._hidden_params[ + 
"original_response" + ] = outputText # allow user to access raw anthropic tool calling response if ( _is_function_call is True and stream is not None @@ -806,9 +806,9 @@ def completion( # noqa: PLR0915 ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v if stream is True: - inference_params["stream"] = ( - True # cohere requires stream = True in inference params - ) + inference_params[ + "stream" + ] = True # cohere requires stream = True in inference params data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "anthropic": if model.startswith("anthropic.claude-3"): @@ -1205,7 +1205,6 @@ def _get_model_id_for_llama_like_model( def get_response_stream_shape(): global _response_stream_shape_cache if _response_stream_shape_cache is None: - from botocore.loaders import Loader from botocore.model import ServiceModel @@ -1539,7 +1538,6 @@ def __init__( model: str, sync_stream: bool, ) -> None: - super().__init__(model=model) from litellm.llms.bedrock.chat.invoke_transformations.amazon_deepseek_transformation import ( AmazonDeepseekR1ResponseIterator, diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py index 133eb659df89..d1212705d877 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py @@ -225,9 +225,9 @@ def transform_request( ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v if stream is True: - inference_params["stream"] = ( - True # cohere requires stream = True in inference params - ) + inference_params[ + "stream" + ] = True # cohere requires stream = True in inference params request_data = {"prompt": prompt, **inference_params} elif provider == "anthropic": return litellm.AmazonAnthropicClaude3Config().transform_request( @@ -311,7 +311,6 @@ def transform_response( # noqa: PLR0915 api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: - try: completion_response = raw_response.json() except Exception: diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index 4677a579ed96..f4a1170660d2 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -314,7 +314,6 @@ def get_bedrock_tool_name(response_tool_name: str) -> str: class BedrockModelInfo(BaseLLMModelInfo): - global_config = AmazonBedrockGlobalConfig() all_global_regions = global_config.get_all_regions() diff --git a/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py b/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py index 6c1147f24a5a..338029adc356 100644 --- a/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py +++ b/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py @@ -33,9 +33,9 @@ def map_openai_params( ) -> dict: for k, v in non_default_params.items(): if k == "dimensions": - optional_params["embeddingConfig"] = ( - AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v) - ) + optional_params[ + "embeddingConfig" + ] = AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v) return optional_params def _transform_request( @@ -58,7 +58,6 @@ def _transform_request( def _transform_response( self, response_list: List[dict], model: str ) -> 
EmbeddingResponse: - total_prompt_tokens = 0 transformed_responses: List[Embedding] = [] for index, response in enumerate(response_list): diff --git a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py index 33249f9af849..68ddf4ca880e 100644 --- a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py +++ b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py @@ -1,12 +1,16 @@ import types -from typing import List, Optional +from typing import Any, Dict, List, Optional from openai.types.image import Image from litellm.types.llms.bedrock import ( - AmazonNovaCanvasTextToImageRequest, AmazonNovaCanvasTextToImageResponse, - AmazonNovaCanvasTextToImageParams, AmazonNovaCanvasRequestBase, AmazonNovaCanvasColorGuidedGenerationParams, + AmazonNovaCanvasColorGuidedGenerationParams, AmazonNovaCanvasColorGuidedRequest, + AmazonNovaCanvasImageGenerationConfig, + AmazonNovaCanvasRequestBase, + AmazonNovaCanvasTextToImageParams, + AmazonNovaCanvasTextToImageRequest, + AmazonNovaCanvasTextToImageResponse, ) from litellm.types.utils import ImageResponse @@ -23,7 +27,7 @@ def get_config(cls): k: v for k, v in cls.__dict__.items() if not k.startswith("__") - and not isinstance( + and not isinstance( v, ( types.FunctionType, @@ -32,13 +36,12 @@ def get_config(cls): staticmethod, ), ) - and v is not None + and v is not None } @classmethod def get_supported_openai_params(cls, model: Optional[str] = None) -> List: - """ - """ + """ """ return ["n", "size", "quality"] @classmethod @@ -56,7 +59,7 @@ def _is_nova_model(cls, model: Optional[str] = None) -> bool: @classmethod def transform_request_body( - cls, text: str, optional_params: dict + cls, text: str, optional_params: dict ) -> AmazonNovaCanvasRequestBase: """ Transform the request body for Amazon Nova Canvas model @@ -65,18 +68,50 @@ def transform_request_body( image_generation_config = optional_params.pop("imageGenerationConfig", {}) image_generation_config = {**image_generation_config, **optional_params} if task_type == "TEXT_IMAGE": - text_to_image_params = image_generation_config.pop("textToImageParams", {}) - text_to_image_params = {"text" :text, **text_to_image_params} - text_to_image_params = AmazonNovaCanvasTextToImageParams(**text_to_image_params) - return AmazonNovaCanvasTextToImageRequest(textToImageParams=text_to_image_params, taskType=task_type, - imageGenerationConfig=image_generation_config) + text_to_image_params: Dict[str, Any] = image_generation_config.pop("textToImageParams", {}) + text_to_image_params = {"text": text, **text_to_image_params} + try: + text_to_image_params_typed = AmazonNovaCanvasTextToImageParams( + **text_to_image_params + ) + except Exception as e: + raise ValueError(f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}") + + try: + image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(**image_generation_config) + except Exception as e: + raise ValueError(f"Error transforming image generation config: {e}. 
Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}") + + return AmazonNovaCanvasTextToImageRequest( + textToImageParams=text_to_image_params_typed, + taskType=task_type, + imageGenerationConfig=image_generation_config_typed, + ) if task_type == "COLOR_GUIDED_GENERATION": - color_guided_generation_params = image_generation_config.pop("colorGuidedGenerationParams", {}) - color_guided_generation_params = {"text": text, **color_guided_generation_params} - color_guided_generation_params = AmazonNovaCanvasColorGuidedGenerationParams(**color_guided_generation_params) - return AmazonNovaCanvasColorGuidedRequest(taskType=task_type, - colorGuidedGenerationParams=color_guided_generation_params, - imageGenerationConfig=image_generation_config) + color_guided_generation_params: Dict[str, Any] = image_generation_config.pop( + "colorGuidedGenerationParams", {} + ) + color_guided_generation_params = { + "text": text, + **color_guided_generation_params, + } + try: + color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams( + **color_guided_generation_params + ) + except Exception as e: + raise ValueError(f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}") + + try: + image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(**image_generation_config) + except Exception as e: + raise ValueError(f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}") + + return AmazonNovaCanvasColorGuidedRequest( + taskType=task_type, + colorGuidedGenerationParams=color_guided_generation_params_typed, + imageGenerationConfig=image_generation_config_typed, + ) raise NotImplementedError(f"Task type {task_type} is not supported") @classmethod @@ -87,7 +122,9 @@ def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> d _size = non_default_params.get("size") if _size is not None: width, height = _size.split("x") - optional_params["width"], optional_params["height"] = int(width), int(height) + optional_params["width"], optional_params["height"] = int(width), int( + height + ) if non_default_params.get("n") is not None: optional_params["numberOfImages"] = non_default_params.get("n") if non_default_params.get("quality") is not None: @@ -99,7 +136,7 @@ def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> d @classmethod def transform_response_dict_to_openai_response( - cls, model_response: ImageResponse, response_dict: dict + cls, model_response: ImageResponse, response_dict: dict ) -> ImageResponse: """ Transform the response dict to the OpenAI response diff --git a/litellm/llms/bedrock/image/image_handler.py b/litellm/llms/bedrock/image/image_handler.py index 8f7762e547a4..27258aa20f45 100644 --- a/litellm/llms/bedrock/image/image_handler.py +++ b/litellm/llms/bedrock/image/image_handler.py @@ -267,7 +267,11 @@ def _get_request_body( **inference_params, } elif provider == "amazon": - return dict(litellm.AmazonNovaCanvasConfig.transform_request_body(text=prompt, optional_params=optional_params)) + return dict( + litellm.AmazonNovaCanvasConfig.transform_request_body( + text=prompt, optional_params=optional_params + ) + ) else: raise BedrockError( status_code=422, message=f"Unsupported model={model}, passed in" @@ -303,8 +307,11 @@ def 
_transform_response_dict_to_openai_response( config_class = ( litellm.AmazonStability3Config if litellm.AmazonStability3Config._is_stability_3_model(model=model) - else litellm.AmazonNovaCanvasConfig if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model) - else litellm.AmazonStabilityConfig + else ( + litellm.AmazonNovaCanvasConfig + if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model) + else litellm.AmazonStabilityConfig + ) ) config_class.transform_response_dict_to_openai_response( model_response=model_response, diff --git a/litellm/llms/bedrock/rerank/handler.py b/litellm/llms/bedrock/rerank/handler.py index cd8be6912c6a..f5a532bec157 100644 --- a/litellm/llms/bedrock/rerank/handler.py +++ b/litellm/llms/bedrock/rerank/handler.py @@ -60,7 +60,6 @@ def rerank( extra_headers: Optional[dict] = None, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: - request_data = RerankRequest( model=model, query=query, diff --git a/litellm/llms/bedrock/rerank/transformation.py b/litellm/llms/bedrock/rerank/transformation.py index a5380febe9cd..be8250a96719 100644 --- a/litellm/llms/bedrock/rerank/transformation.py +++ b/litellm/llms/bedrock/rerank/transformation.py @@ -29,7 +29,6 @@ class BedrockRerankConfig: - def _transform_sources( self, documents: List[Union[str, dict]] ) -> List[BedrockRerankSource]: diff --git a/litellm/llms/codestral/completion/handler.py b/litellm/llms/codestral/completion/handler.py index fc6d2886a99d..555f7fccfb7c 100644 --- a/litellm/llms/codestral/completion/handler.py +++ b/litellm/llms/codestral/completion/handler.py @@ -314,7 +314,6 @@ def completion( return _response ### SYNC COMPLETION else: - response = litellm.module_level_client.post( url=completion_url, headers=headers, @@ -352,13 +351,11 @@ async def async_completion( logger_fn=None, headers={}, ) -> TextCompletionResponse: - async_handler = get_async_httpx_client( llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL, params={"timeout": timeout}, ) try: - response = await async_handler.post( api_base, headers=headers, data=json.dumps(data) ) diff --git a/litellm/llms/codestral/completion/transformation.py b/litellm/llms/codestral/completion/transformation.py index 5955e91debda..fc7b6f5dbb29 100644 --- a/litellm/llms/codestral/completion/transformation.py +++ b/litellm/llms/codestral/completion/transformation.py @@ -78,7 +78,6 @@ def map_openai_params( return optional_params def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk: - text = "" is_finished = False finish_reason = None diff --git a/litellm/llms/cohere/chat/transformation.py b/litellm/llms/cohere/chat/transformation.py index 3ceec2dbba4c..fbaedca8f63f 100644 --- a/litellm/llms/cohere/chat/transformation.py +++ b/litellm/llms/cohere/chat/transformation.py @@ -180,7 +180,6 @@ def transform_request( litellm_params: dict, headers: dict, ) -> dict: - ## Load Config for k, v in litellm.CohereChatConfig.get_config().items(): if ( @@ -222,7 +221,6 @@ def transform_response( api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: - try: raw_response_json = raw_response.json() model_response.choices[0].message.content = raw_response_json["text"] # type: ignore diff --git a/litellm/llms/cohere/embed/handler.py b/litellm/llms/cohere/embed/handler.py index e7f22ea72ada..7a25bf7e5410 100644 --- a/litellm/llms/cohere/embed/handler.py +++ b/litellm/llms/cohere/embed/handler.py @@ -56,7 +56,6 @@ async def async_embedding( encoding: Callable, client: Optional[AsyncHTTPHandler] 
= None, ): - ## LOGGING logging_obj.pre_call( input=input, diff --git a/litellm/llms/cohere/embed/transformation.py b/litellm/llms/cohere/embed/transformation.py index 22e157a0fd9e..837dd5e006e5 100644 --- a/litellm/llms/cohere/embed/transformation.py +++ b/litellm/llms/cohere/embed/transformation.py @@ -72,7 +72,6 @@ def _transform_request( return transformed_request def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage: - input_tokens = 0 text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens") @@ -111,7 +110,6 @@ def _transform_response( encoding: Any, input: list, ) -> EmbeddingResponse: - response_json = response.json() ## LOGGING logging_obj.post_call( diff --git a/litellm/llms/cohere/rerank/transformation.py b/litellm/llms/cohere/rerank/transformation.py index f3624d9216c7..22782c13008f 100644 --- a/litellm/llms/cohere/rerank/transformation.py +++ b/litellm/llms/cohere/rerank/transformation.py @@ -148,4 +148,4 @@ def transform_rerank_response( def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: - return CohereError(message=error_message, status_code=status_code) \ No newline at end of file + return CohereError(message=error_message, status_code=status_code) diff --git a/litellm/llms/cohere/rerank_v2/transformation.py b/litellm/llms/cohere/rerank_v2/transformation.py index a93cb982a710..74e760460d05 100644 --- a/litellm/llms/cohere/rerank_v2/transformation.py +++ b/litellm/llms/cohere/rerank_v2/transformation.py @@ -3,6 +3,7 @@ from litellm.llms.cohere.rerank.transformation import CohereRerankConfig from litellm.types.rerank import OptionalRerankParams, RerankRequest + class CohereRerankV2Config(CohereRerankConfig): """ Reference: https://docs.cohere.com/v2/reference/rerank @@ -77,4 +78,4 @@ def transform_rerank_request( return_documents=optional_rerank_params.get("return_documents", None), max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None), ) - return rerank_request.model_dump(exclude_none=True) \ No newline at end of file + return rerank_request.model_dump(exclude_none=True) diff --git a/litellm/llms/custom_httpx/aiohttp_handler.py b/litellm/llms/custom_httpx/aiohttp_handler.py index c865fee17e3e..9568ce718501 100644 --- a/litellm/llms/custom_httpx/aiohttp_handler.py +++ b/litellm/llms/custom_httpx/aiohttp_handler.py @@ -32,7 +32,6 @@ class BaseLLMAIOHTTPHandler: - def __init__(self): self.client_session: Optional[aiohttp.ClientSession] = None @@ -110,7 +109,6 @@ def _make_common_sync_call( content: Any = None, params: Optional[dict] = None, ) -> httpx.Response: - max_retry_on_unprocessable_entity_error = ( provider_config.max_retry_on_unprocessable_entity_error ) diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 34d70434d542..23d7fe4b4d4a 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -114,7 +114,6 @@ def create_client( event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]], ssl_verify: Optional[VerifyTypes] = None, ) -> httpx.AsyncClient: - # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts. 
# /path/to/certificate.pem if ssl_verify is None: @@ -590,7 +589,6 @@ def patch( timeout: Optional[Union[float, httpx.Timeout]] = None, ): try: - if timeout is not None: req = self.client.build_request( "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore @@ -609,7 +607,6 @@ def patch( llm_provider="litellm-httpx-handler", ) except httpx.HTTPStatusError as e: - if stream is True: setattr(e, "message", mask_sensitive_info(e.response.read())) setattr(e, "text", mask_sensitive_info(e.response.read())) @@ -635,7 +632,6 @@ def put( timeout: Optional[Union[float, httpx.Timeout]] = None, ): try: - if timeout is not None: req = self.client.build_request( "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 872626c74797..12736640f177 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -41,7 +41,6 @@ class BaseLLMHTTPHandler: - async def _make_common_async_call( self, async_httpx_client: AsyncHTTPHandler, @@ -109,7 +108,6 @@ def _make_common_sync_call( logging_obj: LiteLLMLoggingObj, stream: bool = False, ) -> httpx.Response: - max_retry_on_unprocessable_entity_error = ( provider_config.max_retry_on_unprocessable_entity_error ) @@ -599,7 +597,6 @@ def embedding( aembedding: bool = False, headers={}, ) -> EmbeddingResponse: - provider_config = ProviderConfigManager.get_provider_embedding_config( model=model, provider=litellm.LlmProviders(custom_llm_provider) ) @@ -742,7 +739,6 @@ def rerank( api_base: Optional[str] = None, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: - # get config from model, custom llm provider headers = provider_config.validate_environment( api_key=api_key, @@ -828,7 +824,6 @@ async def arerank( timeout: Optional[Union[float, httpx.Timeout]] = None, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: - if client is None or not isinstance(client, AsyncHTTPHandler): async_httpx_client = get_async_httpx_client( llm_provider=litellm.LlmProviders(custom_llm_provider) diff --git a/litellm/llms/databricks/common_utils.py b/litellm/llms/databricks/common_utils.py index e8481e25b2c8..76bd281d4d2f 100644 --- a/litellm/llms/databricks/common_utils.py +++ b/litellm/llms/databricks/common_utils.py @@ -16,9 +16,9 @@ def _get_databricks_credentials( api_base = api_base or f"{databricks_client.config.host}/serving-endpoints" if api_key is None: - databricks_auth_headers: dict[str, str] = ( - databricks_client.config.authenticate() - ) + databricks_auth_headers: dict[ + str, str + ] = databricks_client.config.authenticate() headers = {**databricks_auth_headers, **headers} return api_base, headers diff --git a/litellm/llms/databricks/embed/transformation.py b/litellm/llms/databricks/embed/transformation.py index 53e3b30dd213..a113a349cc67 100644 --- a/litellm/llms/databricks/embed/transformation.py +++ b/litellm/llms/databricks/embed/transformation.py @@ -11,9 +11,9 @@ class DatabricksEmbeddingConfig: Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task """ - instruction: Optional[str] = ( - None # An optional instruction to pass to the embedding model. 
BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries - ) + instruction: Optional[ + str + ] = None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries def __init__(self, instruction: Optional[str] = None) -> None: locals_ = locals().copy() diff --git a/litellm/llms/databricks/streaming_utils.py b/litellm/llms/databricks/streaming_utils.py index 2db53df90838..eebe31828810 100644 --- a/litellm/llms/databricks/streaming_utils.py +++ b/litellm/llms/databricks/streaming_utils.py @@ -55,7 +55,6 @@ def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None) if usage_chunk is not None: - usage = ChatCompletionUsageBlock( prompt_tokens=usage_chunk.prompt_tokens, completion_tokens=usage_chunk.completion_tokens, diff --git a/litellm/llms/deepgram/audio_transcription/transformation.py b/litellm/llms/deepgram/audio_transcription/transformation.py index 90720a77f7f0..20599e3994b5 100644 --- a/litellm/llms/deepgram/audio_transcription/transformation.py +++ b/litellm/llms/deepgram/audio_transcription/transformation.py @@ -126,9 +126,9 @@ def transform_audio_transcription_response( # Add additional metadata matching OpenAI format response["task"] = "transcribe" - response["language"] = ( - "english" # Deepgram auto-detects but doesn't return language - ) + response[ + "language" + ] = "english" # Deepgram auto-detects but doesn't return language response["duration"] = response_json["metadata"]["duration"] # Transform words to match OpenAI format diff --git a/litellm/llms/deepseek/chat/transformation.py b/litellm/llms/deepseek/chat/transformation.py index 180cf7dc6938..fe70ebe77ef8 100644 --- a/litellm/llms/deepseek/chat/transformation.py +++ b/litellm/llms/deepseek/chat/transformation.py @@ -14,7 +14,6 @@ class DeepSeekChatConfig(OpenAIGPTConfig): - def _transform_messages( self, messages: List[AllMessageValues], model: str ) -> List[AllMessageValues]: diff --git a/litellm/llms/deprecated_providers/aleph_alpha.py b/litellm/llms/deprecated_providers/aleph_alpha.py index 81ad13464141..4cfede2a1b9c 100644 --- a/litellm/llms/deprecated_providers/aleph_alpha.py +++ b/litellm/llms/deprecated_providers/aleph_alpha.py @@ -77,9 +77,9 @@ class AlephAlphaConfig: - `control_log_additive` (boolean; default value: true): Method of applying control to attention scores. 
""" - maximum_tokens: Optional[int] = ( - litellm.max_tokens - ) # aleph alpha requires max tokens + maximum_tokens: Optional[ + int + ] = litellm.max_tokens # aleph alpha requires max tokens minimum_tokens: Optional[int] = None echo: Optional[bool] = None temperature: Optional[int] = None diff --git a/litellm/llms/fireworks_ai/chat/transformation.py b/litellm/llms/fireworks_ai/chat/transformation.py index 1c82f24ac088..4def12adb757 100644 --- a/litellm/llms/fireworks_ai/chat/transformation.py +++ b/litellm/llms/fireworks_ai/chat/transformation.py @@ -88,7 +88,6 @@ def map_openai_params( model: str, drop_params: bool, ) -> dict: - supported_openai_params = self.get_supported_openai_params(model=model) is_tools_set = any( param == "tools" and value is not None @@ -104,7 +103,6 @@ def map_openai_params( # pass through the value of tool choice optional_params["tool_choice"] = value elif param == "response_format": - if ( is_tools_set ): # fireworks ai doesn't support tools and response_format together @@ -223,7 +221,6 @@ def _get_openai_compatible_provider_info( return api_base, dynamic_api_key def get_models(self, api_key: Optional[str] = None, api_base: Optional[str] = None): - api_base, api_key = self._get_openai_compatible_provider_info( api_base=api_base, api_key=api_key ) diff --git a/litellm/llms/gemini/chat/transformation.py b/litellm/llms/gemini/chat/transformation.py index fbc1916dccef..0d5956122e6d 100644 --- a/litellm/llms/gemini/chat/transformation.py +++ b/litellm/llms/gemini/chat/transformation.py @@ -90,7 +90,6 @@ def map_openai_params( model: str, drop_params: bool, ) -> Dict: - if litellm.vertex_ai_safety_settings is not None: optional_params["safety_settings"] = litellm.vertex_ai_safety_settings return super().map_openai_params( diff --git a/litellm/llms/gemini/common_utils.py b/litellm/llms/gemini/common_utils.py index 7f266c05367a..4c3357a500dd 100644 --- a/litellm/llms/gemini/common_utils.py +++ b/litellm/llms/gemini/common_utils.py @@ -25,7 +25,6 @@ def get_base_model(model: str) -> Optional[str]: def get_models( self, api_key: Optional[str] = None, api_base: Optional[str] = None ) -> List[str]: - api_base = GeminiModelInfo.get_api_base(api_base) api_key = GeminiModelInfo.get_api_key(api_key) if api_base is None or api_key is None: diff --git a/litellm/llms/groq/chat/transformation.py b/litellm/llms/groq/chat/transformation.py index 5b24f7d1124c..b0ee69bed2a9 100644 --- a/litellm/llms/groq/chat/transformation.py +++ b/litellm/llms/groq/chat/transformation.py @@ -18,7 +18,6 @@ class GroqChatConfig(OpenAIGPTConfig): - frequency_penalty: Optional[int] = None function_call: Optional[Union[str, dict]] = None functions: Optional[list] = None diff --git a/litellm/llms/groq/stt/transformation.py b/litellm/llms/groq/stt/transformation.py index c4dbd8d0caf2..b467fab14f6e 100644 --- a/litellm/llms/groq/stt/transformation.py +++ b/litellm/llms/groq/stt/transformation.py @@ -9,7 +9,6 @@ class GroqSTTConfig: - frequency_penalty: Optional[int] = None function_call: Optional[Union[str, dict]] = None functions: Optional[list] = None diff --git a/litellm/llms/huggingface/chat/transformation.py b/litellm/llms/huggingface/chat/transformation.py index 858fda473ea8..082960b2c2d2 100644 --- a/litellm/llms/huggingface/chat/transformation.py +++ b/litellm/llms/huggingface/chat/transformation.py @@ -40,17 +40,17 @@ class HuggingfaceChatConfig(BaseConfig): Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate """ - hf_task: 
Optional[hf_tasks] = ( - None # litellm-specific param, used to know the api spec to use when calling huggingface api - ) + hf_task: Optional[ + hf_tasks + ] = None # litellm-specific param, used to know the api spec to use when calling huggingface api best_of: Optional[int] = None decoder_input_details: Optional[bool] = None details: Optional[bool] = True # enables returning logprobs + best of max_new_tokens: Optional[int] = None repetition_penalty: Optional[float] = None - return_full_text: Optional[bool] = ( - False # by default don't return the input as part of the output - ) + return_full_text: Optional[ + bool + ] = False # by default don't return the input as part of the output seed: Optional[int] = None temperature: Optional[float] = None top_k: Optional[int] = None @@ -120,9 +120,9 @@ def map_openai_params( optional_params["top_p"] = value if param == "n": optional_params["best_of"] = value - optional_params["do_sample"] = ( - True # Need to sample if you want best of for hf inference endpoints - ) + optional_params[ + "do_sample" + ] = True # Need to sample if you want best of for hf inference endpoints if param == "stream": optional_params["stream"] = value if param == "stop": @@ -362,9 +362,9 @@ def validate_environment( "content-type": "application/json", } if api_key is not None: - default_headers["Authorization"] = ( - f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens - ) + default_headers[ + "Authorization" + ] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens headers = {**headers, **default_headers} return headers diff --git a/litellm/llms/maritalk.py b/litellm/llms/maritalk.py index 5f2b8d71bca0..418d13b3448a 100644 --- a/litellm/llms/maritalk.py +++ b/litellm/llms/maritalk.py @@ -17,7 +17,6 @@ def __init__( class MaritalkConfig(OpenAIGPTConfig): - def __init__( self, frequency_penalty: Optional[float] = None, diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py index b4db95cfa1cf..b007bbb2bc01 100644 --- a/litellm/llms/ollama/completion/transformation.py +++ b/litellm/llms/ollama/completion/transformation.py @@ -89,9 +89,9 @@ class OllamaConfig(BaseConfig): repeat_penalty: Optional[float] = None temperature: Optional[float] = None seed: Optional[int] = None - stop: Optional[list] = ( - None # stop is a list based on this - https://github.com/ollama/ollama/pull/442 - ) + stop: Optional[ + list + ] = None # stop is a list based on this - https://github.com/ollama/ollama/pull/442 tfs_z: Optional[float] = None num_predict: Optional[int] = None top_k: Optional[int] = None diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py index c765f9797940..1fde23c9c203 100644 --- a/litellm/llms/openai/chat/gpt_transformation.py +++ b/litellm/llms/openai/chat/gpt_transformation.py @@ -391,7 +391,6 @@ def get_model_response_iterator( class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator): - def chunk_parser(self, chunk: dict) -> ModelResponseStream: try: return ModelResponseStream( diff --git a/litellm/llms/openai/completion/handler.py b/litellm/llms/openai/completion/handler.py index 2e60f55b57cb..fa31c487cd27 100644 --- a/litellm/llms/openai/completion/handler.py +++ b/litellm/llms/openai/completion/handler.py @@ -220,7 +220,6 @@ def streaming( client=None, organization=None, ): - if client is None: openai_client = OpenAI( api_key=api_key, diff --git 
a/litellm/llms/openai/completion/transformation.py b/litellm/llms/openai/completion/transformation.py index 1aef72d3fa75..43fbc1f21922 100644 --- a/litellm/llms/openai/completion/transformation.py +++ b/litellm/llms/openai/completion/transformation.py @@ -111,9 +111,9 @@ def convert_to_chat_model_response_object( if "model" in response_object: model_response_object.model = response_object["model"] - model_response_object._hidden_params["original_response"] = ( - response_object # track original response, if users make a litellm.text_completion() request, we can return the original response - ) + model_response_object._hidden_params[ + "original_response" + ] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response return model_response_object except Exception as e: raise e diff --git a/litellm/llms/openai/fine_tuning/handler.py b/litellm/llms/openai/fine_tuning/handler.py index 97b237c75725..2b697f85d2d2 100644 --- a/litellm/llms/openai/fine_tuning/handler.py +++ b/litellm/llms/openai/fine_tuning/handler.py @@ -28,14 +28,7 @@ def get_openai_client( _is_async: bool = False, api_version: Optional[str] = None, litellm_params: Optional[dict] = None, - ) -> Optional[ - Union[ - OpenAI, - AsyncOpenAI, - AzureOpenAI, - AsyncAzureOpenAI, - ] - ]: + ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]: received_args = locals() openai_client: Optional[ Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index deb70b481e94..0545542eadc8 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -266,7 +266,6 @@ def transform_response( api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: - logging_obj.post_call(original_response=raw_response.text) logging_obj.model_call_details["response_headers"] = raw_response.headers final_response_obj = cast( @@ -320,7 +319,6 @@ def chunk_parser(self, chunk: dict) -> ModelResponseStream: class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): - def __init__(self) -> None: super().__init__() @@ -513,7 +511,6 @@ def completion( # type: ignore # noqa: PLR0915 custom_llm_provider: Optional[str] = None, drop_params: Optional[bool] = None, ): - super().completion() try: fake_stream: bool = False @@ -553,7 +550,6 @@ def completion( # type: ignore # noqa: PLR0915 for _ in range( 2 ): # if call fails due to alternating messages, retry with reformatted message - if provider_config is not None: data = provider_config.transform_request( model=model, @@ -649,13 +645,14 @@ def completion( # type: ignore # noqa: PLR0915 }, ) - headers, response = ( - self.make_sync_openai_chat_completion_request( - openai_client=openai_client, - data=data, - timeout=timeout, - logging_obj=logging_obj, - ) + ( + headers, + response, + ) = self.make_sync_openai_chat_completion_request( + openai_client=openai_client, + data=data, + timeout=timeout, + logging_obj=logging_obj, ) logging_obj.model_call_details["response_headers"] = headers @@ -763,7 +760,6 @@ async def acompletion( for _ in range( 2 ): # if call fails due to alternating messages, retry with reformatted message - try: openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore is_async=True, @@ -973,7 +969,6 @@ async def async_streaming( except ( Exception ) as e: # need to exception handle here. async exceptions don't get caught in sync functions. 
- if isinstance(e, OpenAIError): raise e @@ -1246,7 +1241,6 @@ async def aimage_generation( ): response = None try: - openai_aclient = self._get_openai_client( is_async=True, api_key=api_key, @@ -1333,7 +1327,6 @@ def image_generation( ) return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore except OpenAIError as e: - ## LOGGING logging_obj.post_call( input=prompt, @@ -1372,7 +1365,6 @@ def audio_speech( aspeech: Optional[bool] = None, client=None, ) -> HttpxBinaryResponseContent: - if aspeech is not None and aspeech is True: return self.async_audio_speech( model=model, @@ -1419,7 +1411,6 @@ async def async_audio_speech( timeout: Union[float, httpx.Timeout], client=None, ) -> HttpxBinaryResponseContent: - openai_client = cast( AsyncOpenAI, self._get_openai_client( diff --git a/litellm/llms/openai/transcriptions/whisper_transformation.py b/litellm/llms/openai/transcriptions/whisper_transformation.py index 5a7d6481a86b..2d3d611dac55 100644 --- a/litellm/llms/openai/transcriptions/whisper_transformation.py +++ b/litellm/llms/openai/transcriptions/whisper_transformation.py @@ -81,9 +81,9 @@ def transform_audio_transcription_request( if "response_format" not in data or ( data["response_format"] == "text" or data["response_format"] == "json" ): - data["response_format"] = ( - "verbose_json" # ensures 'duration' is received - used for cost calculation - ) + data[ + "response_format" + ] = "verbose_json" # ensures 'duration' is received - used for cost calculation return data diff --git a/litellm/llms/openrouter/chat/transformation.py b/litellm/llms/openrouter/chat/transformation.py index ab4d3c52b99e..0b47167524e9 100644 --- a/litellm/llms/openrouter/chat/transformation.py +++ b/litellm/llms/openrouter/chat/transformation.py @@ -19,7 +19,6 @@ class OpenrouterConfig(OpenAIGPTConfig): - def map_openai_params( self, non_default_params: dict, @@ -42,9 +41,9 @@ def map_openai_params( extra_body["models"] = models if route is not None: extra_body["route"] = route - mapped_openai_params["extra_body"] = ( - extra_body # openai client supports `extra_body` param - ) + mapped_openai_params[ + "extra_body" + ] = extra_body # openai client supports `extra_body` param return mapped_openai_params def get_error_class( @@ -70,7 +69,6 @@ def get_model_response_iterator( class OpenRouterChatCompletionStreamingHandler(BaseModelResponseIterator): - def chunk_parser(self, chunk: dict) -> ModelResponseStream: try: new_choices = [] diff --git a/litellm/llms/petals/completion/transformation.py b/litellm/llms/petals/completion/transformation.py index 08ec15de33b9..a9e37d27fc97 100644 --- a/litellm/llms/petals/completion/transformation.py +++ b/litellm/llms/petals/completion/transformation.py @@ -37,9 +37,9 @@ class PetalsConfig(BaseConfig): """ max_length: Optional[int] = None - max_new_tokens: Optional[int] = ( - litellm.max_tokens - ) # petals requires max tokens to be set + max_new_tokens: Optional[ + int + ] = litellm.max_tokens # petals requires max tokens to be set do_sample: Optional[bool] = None temperature: Optional[float] = None top_k: Optional[int] = None diff --git a/litellm/llms/predibase/chat/handler.py b/litellm/llms/predibase/chat/handler.py index 43f4b0674505..cd80fa53e44b 100644 --- a/litellm/llms/predibase/chat/handler.py +++ b/litellm/llms/predibase/chat/handler.py @@ -394,7 +394,6 @@ async def async_completion( logger_fn=None, headers={}, ) -> ModelResponse: - async_handler = get_async_httpx_client( 
llm_provider=litellm.LlmProviders.PREDIBASE, params={"timeout": timeout}, diff --git a/litellm/llms/predibase/chat/transformation.py b/litellm/llms/predibase/chat/transformation.py index f574238696d6..f1a2163d24d2 100644 --- a/litellm/llms/predibase/chat/transformation.py +++ b/litellm/llms/predibase/chat/transformation.py @@ -30,9 +30,9 @@ class PredibaseConfig(BaseConfig): 256 # openai default - requests hang if max_new_tokens not given ) repetition_penalty: Optional[float] = None - return_full_text: Optional[bool] = ( - False # by default don't return the input as part of the output - ) + return_full_text: Optional[ + bool + ] = False # by default don't return the input as part of the output seed: Optional[int] = None stop: Optional[List[str]] = None temperature: Optional[float] = None @@ -99,9 +99,9 @@ def map_openai_params( optional_params["top_p"] = value if param == "n": optional_params["best_of"] = value - optional_params["do_sample"] = ( - True # Need to sample if you want best of for hf inference endpoints - ) + optional_params[ + "do_sample" + ] = True # Need to sample if you want best of for hf inference endpoints if param == "stream": optional_params["stream"] = value if param == "stop": diff --git a/litellm/llms/replicate/chat/handler.py b/litellm/llms/replicate/chat/handler.py index f52eb2ee05a2..526f376b8910 100644 --- a/litellm/llms/replicate/chat/handler.py +++ b/litellm/llms/replicate/chat/handler.py @@ -244,7 +244,6 @@ async def async_completion( print_verbose, headers: dict, ) -> Union[ModelResponse, CustomStreamWrapper]: - prediction_url = replicate_config.get_complete_url( api_base=api_base, model=model, diff --git a/litellm/llms/sagemaker/chat/handler.py b/litellm/llms/sagemaker/chat/handler.py index c827a8a5f7e6..b86cda7aeafb 100644 --- a/litellm/llms/sagemaker/chat/handler.py +++ b/litellm/llms/sagemaker/chat/handler.py @@ -13,7 +13,6 @@ class SagemakerChatHandler(BaseAWSLLM): - def _load_credentials( self, optional_params: dict, @@ -128,7 +127,6 @@ def completion( headers: dict = {}, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ): - # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker credentials, aws_region_name = self._load_credentials(optional_params) inference_params = deepcopy(optional_params) diff --git a/litellm/llms/sagemaker/common_utils.py b/litellm/llms/sagemaker/common_utils.py index 9884f420c3ac..ce0c6c95062f 100644 --- a/litellm/llms/sagemaker/common_utils.py +++ b/litellm/llms/sagemaker/common_utils.py @@ -34,7 +34,6 @@ def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None: def _chunk_parser_messages_api( self, chunk_data: dict ) -> StreamingChatCompletionChunk: - openai_chunk = StreamingChatCompletionChunk(**chunk_data) return openai_chunk @@ -192,7 +191,6 @@ def _parse_message_from_event(self, event) -> Optional[str]: def get_response_stream_shape(): global _response_stream_shape_cache if _response_stream_shape_cache is None: - from botocore.loaders import Loader from botocore.model import ServiceModel diff --git a/litellm/llms/sagemaker/completion/handler.py b/litellm/llms/sagemaker/completion/handler.py index 909caf73c3db..fcae9d6c03d3 100644 --- a/litellm/llms/sagemaker/completion/handler.py +++ b/litellm/llms/sagemaker/completion/handler.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union, cast import httpx @@ -35,7 +35,6 @@ # set 
os.environ['AWS_REGION_NAME'] = class SagemakerLLM(BaseAWSLLM): - def _load_credentials( self, optional_params: dict, @@ -154,7 +153,6 @@ def completion( # noqa: PLR0915 acompletion: bool = False, headers: dict = {}, ): - # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker credentials, aws_region_name = self._load_credentials(optional_params) inference_params = deepcopy(optional_params) @@ -437,10 +435,14 @@ async def async_streaming( prepared_request.headers.update( {"X-Amzn-SageMaker-Inference-Component": model_id} ) + + if not prepared_request.body: + raise ValueError("Prepared request body is empty") + completion_stream = await self.make_async_call( api_base=prepared_request.url, headers=prepared_request.headers, # type: ignore - data=prepared_request.body, + data=cast(str, prepared_request.body), logging_obj=logging_obj, ) streaming_response = CustomStreamWrapper( diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py index d0ab5d069760..9923c0e45d0a 100644 --- a/litellm/llms/sagemaker/completion/transformation.py +++ b/litellm/llms/sagemaker/completion/transformation.py @@ -88,9 +88,9 @@ def map_openai_params( optional_params["top_p"] = value if param == "n": optional_params["best_of"] = value - optional_params["do_sample"] = ( - True # Need to sample if you want best of for hf inference endpoints - ) + optional_params[ + "do_sample" + ] = True # Need to sample if you want best of for hf inference endpoints if param == "stream": optional_params["stream"] = value if param == "stop": diff --git a/litellm/llms/together_ai/rerank/transformation.py b/litellm/llms/together_ai/rerank/transformation.py index 47143769794f..1fdb772addec 100644 --- a/litellm/llms/together_ai/rerank/transformation.py +++ b/litellm/llms/together_ai/rerank/transformation.py @@ -19,7 +19,6 @@ class TogetherAIRerankConfig: def _transform_response(self, response: dict) -> RerankResponse: - _billed_units = RerankBilledUnits(**response.get("usage", {})) _tokens = RerankTokens(**response.get("usage", {})) rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens) diff --git a/litellm/llms/topaz/image_variations/transformation.py b/litellm/llms/topaz/image_variations/transformation.py index 8b95deed0466..6188101015ef 100644 --- a/litellm/llms/topaz/image_variations/transformation.py +++ b/litellm/llms/topaz/image_variations/transformation.py @@ -121,7 +121,6 @@ def transform_request_image_variation( optional_params: dict, headers: dict, ) -> HttpHandlerRequestFields: - request_params = HttpHandlerRequestFields( files={"image": self.prepare_file_tuple(image)}, data=optional_params, @@ -134,7 +133,6 @@ def _common_transform_response_image_variation( image_content: bytes, response_ms: float, ) -> ImageResponse: - # Convert to base64 base64_image = base64.b64encode(image_content).decode("utf-8") diff --git a/litellm/llms/triton/completion/transformation.py b/litellm/llms/triton/completion/transformation.py index 56151f89ef01..46b607d4556c 100644 --- a/litellm/llms/triton/completion/transformation.py +++ b/litellm/llms/triton/completion/transformation.py @@ -244,7 +244,6 @@ def transform_request( litellm_params: dict, headers: dict, ) -> dict: - text_input = messages[0].get("content", "") data_for_triton = { "inputs": [ diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index b82268bef6ae..dc3f93857aa1 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ 
b/litellm/llms/vertex_ai/batches/handler.py
@@ -35,7 +35,6 @@ def create_batch(
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
     ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
-
         sync_handler = _get_httpx_client()
 
         access_token, project_id = self._ensure_access_token(
@@ -69,10 +68,8 @@ def create_batch(
             "Authorization": f"Bearer {access_token}",
         }
 
-        vertex_batch_request: VertexAIBatchPredictionJob = (
-            VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
-                request=create_batch_data
-            )
+        vertex_batch_request: VertexAIBatchPredictionJob = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
+            request=create_batch_data
         )
 
         if _is_async is True:
diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py
index 337445777ad5..f2cd1ef557bc 100644
--- a/litellm/llms/vertex_ai/common_utils.py
+++ b/litellm/llms/vertex_ai/common_utils.py
@@ -243,7 +243,7 @@ def convert_anyof_null_to_nullable(schema, depth=0):
                 # remove null type
                 anyof.remove(atype)
                 contains_null = True
-
+
         if len(anyof) == 0:
             # Edge case: response schema with only null type present is invalid in Vertex AI
             raise ValueError(
@@ -251,12 +251,10 @@ def convert_anyof_null_to_nullable(schema, depth=0):
                 "Please provide a non-null type."
             )
 
-
         if contains_null:
             # set all types to nullable following guidance found here: https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-controlled-generation-response-schema-3#generativeaionvertexai_gemini_controlled_generation_response_schema_3-python
             for atype in anyof:
                 atype["nullable"] = True
 
-
     properties = schema.get("properties", None)
     if properties is not None:
diff --git a/litellm/llms/vertex_ai/files/handler.py b/litellm/llms/vertex_ai/files/handler.py
index 266169cdfb47..7000cf151d90 100644
--- a/litellm/llms/vertex_ai/files/handler.py
+++ b/litellm/llms/vertex_ai/files/handler.py
@@ -49,10 +49,11 @@ async def async_create_file(
             service_account_json=gcs_logging_config["path_service_account"],
         )
         bucket_name = gcs_logging_config["bucket_name"]
-        logging_payload, object_name = (
-            vertex_ai_files_transformation.transform_openai_file_content_to_vertex_ai_file_content(
-                openai_file_content=create_file_data.get("file")
-            )
+        (
+            logging_payload,
+            object_name,
+        ) = vertex_ai_files_transformation.transform_openai_file_content_to_vertex_ai_file_content(
+            openai_file_content=create_file_data.get("file")
         )
         gcs_upload_response = await self._log_json_data_on_gcs(
             headers=headers,
diff --git a/litellm/llms/vertex_ai/fine_tuning/handler.py b/litellm/llms/vertex_ai/fine_tuning/handler.py
index 3cf409c78e7f..7ea8527fd413 100644
--- a/litellm/llms/vertex_ai/fine_tuning/handler.py
+++ b/litellm/llms/vertex_ai/fine_tuning/handler.py
@@ -36,7 +36,6 @@ def __init__(self) -> None:
 
     def convert_response_created_at(self, response: ResponseTuningJob):
         try:
-
             create_time_str = response.get("createTime", "") or ""
             create_time_datetime = datetime.fromisoformat(
                 create_time_str.replace("Z", "+00:00")
@@ -65,9 +64,9 @@ def convert_openai_request_to_vertex(
         )
 
         if create_fine_tuning_job_data.validation_file:
-            supervised_tuning_spec["validation_dataset"] = (
-                create_fine_tuning_job_data.validation_file
-            )
+            supervised_tuning_spec[
+                "validation_dataset"
+            ] = create_fine_tuning_job_data.validation_file
 
         _vertex_hyperparameters = (
             self._transform_openai_hyperparameters_to_vertex_hyperparameters(
@@ -175,7 +174,6 @@ async def acreate_fine_tuning_job(
         headers: dict,
         request_data: FineTuneJobCreate,
     ):
-
         try:
             verbose_logger.debug(
                 "about to create fine tuning job: %s, request_data: %s",
@@ -229,7 +227,6 @@ def create_fine_tuning_job(
         kwargs: Optional[dict] = None,
         original_hyperparameters: Optional[dict] = {},
     ):
-
         verbose_logger.debug(
             "creating fine tuning job, args= %s", create_fine_tuning_job_data
         )
@@ -346,9 +343,9 @@ async def pass_through_vertex_ai_POST_request(
     elif "cachedContents" in request_route:
         _model = request_data.get("model")
         if _model is not None and "/publishers/google/models/" not in _model:
-            request_data["model"] = (
-                f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}"
-            )
+            request_data[
+                "model"
+            ] = f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}"
 
         url = f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
     else:
diff --git a/litellm/llms/vertex_ai/gemini/transformation.py b/litellm/llms/vertex_ai/gemini/transformation.py
index d6bafc7c60f3..96b33ee1875e 100644
--- a/litellm/llms/vertex_ai/gemini/transformation.py
+++ b/litellm/llms/vertex_ai/gemini/transformation.py
@@ -85,7 +85,6 @@ def _process_gemini_image(image_url: str, format: Optional[str] = None) -> PartT
         and (image_type := format or _get_image_mime_type_from_url(image_url))
         is not None
     ):
-
         file_data = FileDataType(file_uri=image_url, mime_type=image_type)
         return PartType(file_data=file_data)
     elif "http://" in image_url or "https://" in image_url or "base64" in image_url:
@@ -414,18 +413,19 @@ async def async_transform_request_body(
     context_caching_endpoints = ContextCachingEndpoints()
 
     if gemini_api_key is not None:
-        messages, cached_content = (
-            await context_caching_endpoints.async_check_and_create_cache(
-                messages=messages,
-                api_key=gemini_api_key,
-                api_base=api_base,
-                model=model,
-                client=client,
-                timeout=timeout,
-                extra_headers=extra_headers,
-                cached_content=optional_params.pop("cached_content", None),
-                logging_obj=logging_obj,
-            )
+        (
+            messages,
+            cached_content,
+        ) = await context_caching_endpoints.async_check_and_create_cache(
+            messages=messages,
+            api_key=gemini_api_key,
+            api_base=api_base,
+            model=model,
+            client=client,
+            timeout=timeout,
+            extra_headers=extra_headers,
+            cached_content=optional_params.pop("cached_content", None),
+            logging_obj=logging_obj,
         )
     else:  # [TODO] implement context caching for gemini as well
         cached_content = optional_params.pop("cached_content", None)
diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
index 90c66f69a3dd..860dec9eb2d1 100644
--- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
@@ -246,9 +246,9 @@ def _map_function(self, value: List[dict]) -> List[Tools]:
         value = _remove_strict_from_schema(value)
 
         for tool in value:
-            openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
-                None
-            )
+            openai_function_object: Optional[
+                ChatCompletionToolParamFunctionChunk
+            ] = None
             if "function" in tool:  # tools list
                 _openai_function_object = ChatCompletionToolParamFunctionChunk(  # type: ignore
                     **tool["function"]
@@ -813,15 +813,15 @@ def transform_response(
 
             ## ADD SAFETY RATINGS ##
             setattr(model_response, "vertex_ai_safety_results", safety_ratings)
-            model_response._hidden_params["vertex_ai_safety_results"] = (
-                safety_ratings  # older approach - maintaining to prevent regressions
-            )
+            model_response._hidden_params[
+                "vertex_ai_safety_results"
+            ] = safety_ratings  # older approach - maintaining to prevent regressions
 
             ## ADD CITATION METADATA ##
             setattr(model_response, "vertex_ai_citation_metadata", citation_metadata)
-            model_response._hidden_params["vertex_ai_citation_metadata"] = (
-                citation_metadata  # older approach - maintaining to prevent regressions
-            )
+            model_response._hidden_params[
+                "vertex_ai_citation_metadata"
+            ] = citation_metadata  # older approach - maintaining to prevent regressions
 
         except Exception as e:
             raise VertexAIError(
diff --git a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py
index 0fe5145a1443..ecfe2ee8b4bb 100644
--- a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py
+++ b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py
@@ -47,7 +47,6 @@ def batch_embeddings(
         timeout=300,
         client=None,
     ) -> EmbeddingResponse:
-
         _auth_header, vertex_project = self._ensure_access_token(
             credentials=vertex_credentials,
             project_id=vertex_project,
diff --git a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py
index 592dac584616..2c0f5dad2280 100644
--- a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py
+++ b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py
@@ -52,7 +52,6 @@ def process_response(
     model: str,
     _predictions: VertexAIBatchEmbeddingsResponseObject,
 ) -> EmbeddingResponse:
-
     openai_embeddings: List[Embedding] = []
     for embedding in _predictions["embeddings"]:
         openai_embedding = Embedding(
diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
index 34879ae9acd3..88d7339449e7 100644
--- a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
+++ b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
@@ -50,7 +50,6 @@ def multimodal_embedding(
         timeout=300,
         client=None,
     ) -> EmbeddingResponse:
-
         _auth_header, vertex_project = self._ensure_access_token(
             credentials=vertex_credentials,
             project_id=vertex_project,
diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py b/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
index 8f96ca2bb025..afa58c7e5cf0 100644
--- a/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
+++ b/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
@@ -260,7 +260,6 @@ def calculate_usage(
 
     def transform_embedding_response_to_openai(
         self, predictions: MultimodalPredictions
     ) -> List[Embedding]:
-
         openai_embeddings: List[Embedding] = []
         if "predictions" in predictions:
             for idx, _prediction in enumerate(predictions["predictions"]):
diff --git a/litellm/llms/vertex_ai/vertex_ai_non_gemini.py b/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
index 744e1eb3177d..df267d9623b2 100644
--- a/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
+++ b/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
@@ -323,7 +323,6 @@ def completion(  # noqa: PLR0915
             )
             completion_response = chat.send_message(prompt, **optional_params).text
         elif mode == "text":
-
             if fake_stream is not True and stream is True:
                 request_str += (
                     f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
@@ -506,7 +505,6 @@ async def async_completion(  # noqa:
PLR0915 Add support for acompletion calls for gemini-pro """ try: - response_obj = None completion_response = None if mode == "chat": diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py index fb2393631bc2..b8d2658f8076 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py @@ -110,7 +110,6 @@ def completion( message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""", ) try: - vertex_httpx_logic = VertexLLM() access_token, project_id = vertex_httpx_logic._ensure_access_token( diff --git a/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py b/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py index 2e8051d4d2ef..1167ca285fc9 100644 --- a/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py +++ b/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py @@ -86,10 +86,8 @@ def embedding( mode="embedding", ) headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) - vertex_request: VertexEmbeddingRequest = ( - litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( - input=input, optional_params=optional_params, model=model - ) + vertex_request: VertexEmbeddingRequest = litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( + input=input, optional_params=optional_params, model=model ) _client_params = {} @@ -178,10 +176,8 @@ async def async_embedding( mode="embedding", ) headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) - vertex_request: VertexEmbeddingRequest = ( - litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( - input=input, optional_params=optional_params, model=model - ) + vertex_request: VertexEmbeddingRequest = litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( + input=input, optional_params=optional_params, model=model ) _async_client_params = {} diff --git a/litellm/llms/vertex_ai/vertex_embeddings/transformation.py b/litellm/llms/vertex_ai/vertex_embeddings/transformation.py index d9e84fca038c..97af558041d3 100644 --- a/litellm/llms/vertex_ai/vertex_embeddings/transformation.py +++ b/litellm/llms/vertex_ai/vertex_embeddings/transformation.py @@ -212,7 +212,6 @@ def transform_vertex_response_to_openai( embedding_response = [] input_tokens: int = 0 for idx, element in enumerate(_predictions): - embedding = element["embeddings"] embedding_response.append( { diff --git a/litellm/llms/vertex_ai/vertex_model_garden/main.py b/litellm/llms/vertex_ai/vertex_model_garden/main.py index 7b54d4e34b9c..1c57096734b4 100644 --- a/litellm/llms/vertex_ai/vertex_model_garden/main.py +++ b/litellm/llms/vertex_ai/vertex_model_garden/main.py @@ -76,7 +76,6 @@ def completion( VertexLLM, ) except Exception as e: - raise VertexAIError( status_code=400, message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. 
Got error: {e}""", diff --git a/litellm/llms/watsonx/chat/transformation.py b/litellm/llms/watsonx/chat/transformation.py index f253da6f5b47..2ff1dd6a68be 100644 --- a/litellm/llms/watsonx/chat/transformation.py +++ b/litellm/llms/watsonx/chat/transformation.py @@ -15,7 +15,6 @@ class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig): - def get_supported_openai_params(self, model: str) -> List: return [ "temperature", # equivalent to temperature diff --git a/litellm/main.py b/litellm/main.py index 1e6c36aa6c6e..f69454aaad87 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -946,14 +946,16 @@ def completion( # type: ignore # noqa: PLR0915 ## PROMPT MANAGEMENT HOOKS ## if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None: - model, messages, optional_params = ( - litellm_logging_obj.get_chat_completion_prompt( - model=model, - messages=messages, - non_default_params=non_default_params, - prompt_id=prompt_id, - prompt_variables=prompt_variables, - ) + ( + model, + messages, + optional_params, + ) = litellm_logging_obj.get_chat_completion_prompt( + model=model, + messages=messages, + non_default_params=non_default_params, + prompt_id=prompt_id, + prompt_variables=prompt_variables, ) try: @@ -1246,7 +1248,6 @@ def completion( # type: ignore # noqa: PLR0915 optional_params["max_retries"] = max_retries if litellm.AzureOpenAIO1Config().is_o_series_model(model=model): - ## LOAD CONFIG - if set config = litellm.AzureOpenAIO1Config.get_config() for k, v in config.items(): @@ -2654,9 +2655,9 @@ def completion( # type: ignore # noqa: PLR0915 "aws_region_name" not in optional_params or optional_params["aws_region_name"] is None ): - optional_params["aws_region_name"] = ( - aws_bedrock_client.meta.region_name - ) + optional_params[ + "aws_region_name" + ] = aws_bedrock_client.meta.region_name bedrock_route = BedrockModelInfo.get_bedrock_route(model) if bedrock_route == "converse": @@ -4362,9 +4363,9 @@ def adapter_completion( new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs) response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore - translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = ( - None - ) + translated_response: Optional[ + Union[BaseModel, AdapterCompletionStreamWrapper] + ] = None if isinstance(response, ModelResponse): translated_response = translation_obj.translate_completion_output_params( response=response @@ -4436,13 +4437,16 @@ async def amoderation( optional_params = GenericLiteLLMParams(**kwargs) try: - model, _custom_llm_provider, _dynamic_api_key, _dynamic_api_base = ( - litellm.get_llm_provider( - model=model or "", - custom_llm_provider=custom_llm_provider, - api_base=optional_params.api_base, - api_key=optional_params.api_key, - ) + ( + model, + _custom_llm_provider, + _dynamic_api_key, + _dynamic_api_base, + ) = litellm.get_llm_provider( + model=model or "", + custom_llm_provider=custom_llm_provider, + api_base=optional_params.api_base, + api_key=optional_params.api_key, ) except litellm.BadRequestError: # `model` is optional field for moderation - get_llm_provider will throw BadRequestError if model is not set / not recognized @@ -5405,7 +5409,6 @@ def speech( # noqa: PLR0915 litellm_params=litellm_params_dict, ) elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta": - generic_optional_params = GenericLiteLLMParams(**kwargs) api_base = generic_optional_params.api_base or "" @@ -5460,7 +5463,6 @@ def speech( # noqa: PLR0915 
async def ahealth_check_wildcard_models( model: str, custom_llm_provider: str, model_params: dict ) -> dict: - # this is a wildcard model, we need to pick a random model from the provider cheapest_models = pick_cheapest_chat_models_from_llm_provider( custom_llm_provider=custom_llm_provider, n=3 @@ -5783,9 +5785,9 @@ def stream_chunk_builder( # noqa: PLR0915 ] if len(content_chunks) > 0: - response["choices"][0]["message"]["content"] = ( - processor.get_combined_content(content_chunks) - ) + response["choices"][0]["message"][ + "content" + ] = processor.get_combined_content(content_chunks) audio_chunks = [ chunk diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index e7d6bec00443..394f49df7a23 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -615,9 +615,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): allowed_cache_controls: Optional[list] = [] config: Optional[dict] = {} permissions: Optional[dict] = {} - model_max_budget: Optional[dict] = ( - {} - ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} + model_max_budget: Optional[ + dict + ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} model_config = ConfigDict(protected_namespaces=()) model_rpm_limit: Optional[dict] = None @@ -873,12 +873,12 @@ class NewCustomerRequest(BudgetNewRequest): alias: Optional[str] = None # human-friendly alias blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model @model_validator(mode="before") @classmethod @@ -900,12 +900,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model class DeleteCustomerRequest(LiteLLMPydanticObjectBase): @@ -1040,9 +1040,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase): callback_name: str - callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( - "success_and_failure" - ) + callback_type: Optional[ + Literal["success", "failure", "success_and_failure"] + ] = "success_and_failure" callback_vars: Dict[str, str] @model_validator(mode="before") @@ -1299,9 +1299,9 @@ class ConfigList(LiteLLMPydanticObjectBase): stored_in_db: Optional[bool] field_default_value: Any premium_field: bool = False - nested_fields: Optional[List[FieldDetail]] = ( - None # For nested dictionary or 
Pydantic fields - ) + nested_fields: Optional[ + List[FieldDetail] + ] = None # For nested dictionary or Pydantic fields class ConfigGeneralSettings(LiteLLMPydanticObjectBase): @@ -1567,9 +1567,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): budget_id: Optional[str] = None created_at: datetime updated_at: datetime - user: Optional[Any] = ( - None # You might want to replace 'Any' with a more specific type if available - ) + user: Optional[ + Any + ] = None # You might want to replace 'Any' with a more specific type if available litellm_budget_table: Optional[LiteLLM_BudgetTable] = None model_config = ConfigDict(protected_namespaces=()) @@ -2306,9 +2306,9 @@ class TeamModelDeleteRequest(BaseModel): # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str - max_budget_in_organization: Optional[float] = ( - None # Users max budget within the organization - ) + max_budget_in_organization: Optional[ + float + ] = None # Users max budget within the organization class OrganizationMemberDeleteRequest(MemberDeleteRequest): @@ -2497,9 +2497,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase): Maps provider names to their budget configs. """ - providers: Dict[str, ProviderBudgetResponseObject] = ( - {} - ) # Dictionary mapping provider names to their budget configurations + providers: Dict[ + str, ProviderBudgetResponseObject + ] = {} # Dictionary mapping provider names to their budget configurations class ProxyStateVariables(TypedDict): @@ -2627,9 +2627,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): enforce_rbac: bool = False roles_jwt_field: Optional[str] = None # v2 on role mappings role_mappings: Optional[List[RoleMapping]] = None - object_id_jwt_field: Optional[str] = ( - None # can be either user / team, inferred from the role mapping - ) + object_id_jwt_field: Optional[ + str + ] = None # can be either user / team, inferred from the role mapping scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False enforce_team_based_model_access: bool = False diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index efbfe8d90cb0..4fd718351962 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -551,7 +551,6 @@ def _get_role_based_permissions( return None for role_based_permission in role_based_permissions: - if role_based_permission.role == rbac_role: return getattr(role_based_permission, key) @@ -867,7 +866,6 @@ async def _get_team_object_from_cache( proxy_logging_obj is not None and proxy_logging_obj.internal_usage_cache.dual_cache ): - cached_team_obj = ( await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache( key=key, parent_otel_span=parent_otel_span @@ -1202,7 +1200,6 @@ async def can_user_call_model( llm_router: Optional[Router], user_object: Optional[LiteLLM_UserTable], ) -> Literal[True]: - if user_object is None: return True diff --git a/litellm/proxy/auth/auth_checks_organization.py b/litellm/proxy/auth/auth_checks_organization.py index 3da3d8ddd1e3..e96a5c61fc05 100644 --- a/litellm/proxy/auth/auth_checks_organization.py +++ b/litellm/proxy/auth/auth_checks_organization.py @@ -44,9 +44,10 @@ def organization_role_based_access_check( # Checks if route is an Org Admin Only Route if route in LiteLLMRoutes.org_admin_only_routes.value: - _user_organizations, _user_organization_role_mapping = ( - get_user_organization_info(user_object) - ) + ( + _user_organizations, + 
_user_organization_role_mapping, + ) = get_user_organization_info(user_object) if user_object.organization_memberships is None: raise ProxyException( @@ -84,9 +85,10 @@ def organization_role_based_access_check( ) elif route == "/team/new": # if user is part of multiple teams, then they need to specify the organization_id - _user_organizations, _user_organization_role_mapping = ( - get_user_organization_info(user_object) - ) + ( + _user_organizations, + _user_organization_role_mapping, + ) = get_user_organization_info(user_object) if ( user_object.organization_memberships is not None and len(user_object.organization_memberships) > 0 diff --git a/litellm/proxy/auth/auth_exception_handler.py b/litellm/proxy/auth/auth_exception_handler.py index 05797381c6cf..5dd30a075849 100644 --- a/litellm/proxy/auth/auth_exception_handler.py +++ b/litellm/proxy/auth/auth_exception_handler.py @@ -23,7 +23,6 @@ class UserAPIKeyAuthExceptionHandler: - @staticmethod async def _handle_authentication_error( e: Exception, diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index 2c4b122d3a65..0200457ef9ff 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -14,7 +14,6 @@ def _get_request_ip_address( request: Request, use_x_forwarded_for: Optional[bool] = False ) -> Optional[str]: - client_ip = None if use_x_forwarded_for is True and "x-forwarded-for" in request.headers: client_ip = request.headers["x-forwarded-for"] @@ -469,7 +468,6 @@ def should_run_auth_on_pass_through_provider_route(route: str) -> bool: from litellm.proxy.proxy_server import general_settings, premium_user if premium_user is not True: - return False # premium use has opted into using client credentials diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index cc410501989f..783c2f15530c 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -166,7 +166,6 @@ def get_end_user_id( self, token: dict, default_value: Optional[str] ) -> Optional[str]: try: - if self.litellm_jwtauth.end_user_id_jwt_field is not None: user_id = token[self.litellm_jwtauth.end_user_id_jwt_field] else: @@ -339,7 +338,6 @@ def get_scopes(self, token: dict) -> List[str]: return scopes async def get_public_key(self, kid: Optional[str]) -> dict: - keys_url = os.getenv("JWT_PUBLIC_KEY_URL") if keys_url is None: @@ -348,7 +346,6 @@ async def get_public_key(self, kid: Optional[str]) -> dict: keys_url_list = [url.strip() for url in keys_url.split(",")] for key_url in keys_url_list: - cache_key = f"litellm_jwt_auth_keys_{key_url}" cached_keys = await self.user_api_key_cache.async_get_cache(cache_key) @@ -923,7 +920,6 @@ async def auth_builder( object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None) if rbac_role and object_id: - if rbac_role == LitellmUserRoles.TEAM: team_id = object_id elif rbac_role == LitellmUserRoles.INTERNAL_USER: @@ -940,15 +936,16 @@ async def auth_builder( ## SPECIFIC TEAM ID if not team_id: - team_id, team_object = ( - await JWTAuthManager.find_and_validate_specific_team_id( - jwt_handler, - jwt_valid_token, - prisma_client, - user_api_key_cache, - parent_otel_span, - proxy_logging_obj, - ) + ( + team_id, + team_object, + ) = await JWTAuthManager.find_and_validate_specific_team_id( + jwt_handler, + jwt_valid_token, + prisma_client, + user_api_key_cache, + parent_otel_span, + proxy_logging_obj, ) if not team_object and not team_id: diff --git a/litellm/proxy/auth/litellm_license.py 
b/litellm/proxy/auth/litellm_license.py index 67ec91f51af7..d962aad2c00c 100644 --- a/litellm/proxy/auth/litellm_license.py +++ b/litellm/proxy/auth/litellm_license.py @@ -45,7 +45,6 @@ def read_public_key(self): verbose_proxy_logger.error(f"Error reading public key: {str(e)}") def _verify(self, license_str: str) -> bool: - verbose_proxy_logger.debug( "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format( self.base_url, license_str diff --git a/litellm/proxy/auth/model_checks.py b/litellm/proxy/auth/model_checks.py index a48ef6ae8726..f0f730138fd8 100644 --- a/litellm/proxy/auth/model_checks.py +++ b/litellm/proxy/auth/model_checks.py @@ -178,7 +178,6 @@ def _get_wildcard_models( all_wildcard_models = [] for model in unique_models: if _check_wildcard_routing(model=model): - if ( return_wildcard_routes ): # will add the wildcard route to the list eg: anthropic/*. diff --git a/litellm/proxy/auth/route_checks.py b/litellm/proxy/auth/route_checks.py index 8f956abb72f9..41529512b6b9 100644 --- a/litellm/proxy/auth/route_checks.py +++ b/litellm/proxy/auth/route_checks.py @@ -16,7 +16,6 @@ class RouteChecks: - @staticmethod def non_proxy_admin_allowed_routes_check( user_obj: Optional[LiteLLM_UserTable], @@ -67,7 +66,6 @@ def non_proxy_admin_allowed_routes_check( and getattr(valid_token, "permissions", None) is not None and "get_spend_routes" in getattr(valid_token, "permissions", []) ): - pass elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value: if RouteChecks.is_llm_api_route(route=route): @@ -80,7 +78,6 @@ def non_proxy_admin_allowed_routes_check( ): # the Admin Viewer is only allowed to call /user/update for their own user_id and can only update if route == "/user/update": - # Check the Request params are valid for PROXY_ADMIN_VIEW_ONLY if request_data is not None and isinstance(request_data, dict): _params_updated = request_data.keys() diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index b58353bf056d..eddbf4e0d976 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -206,7 +206,6 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str: def get_model_from_request(request_data: dict, route: str) -> Optional[str]: - # First try to get model from request_data model = request_data.get("model") @@ -229,7 +228,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 azure_apim_header: Optional[str], request_data: dict, ) -> UserAPIKeyAuth: - from litellm.proxy.proxy_server import ( general_settings, jwt_handler, @@ -251,7 +249,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 valid_token: Optional[UserAPIKeyAuth] = None try: - # get the request body await pre_db_read_auth_checks( @@ -514,23 +511,23 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 proxy_logging_obj=proxy_logging_obj, ) if _end_user_object is not None: - end_user_params["allowed_model_region"] = ( - _end_user_object.allowed_model_region - ) + end_user_params[ + "allowed_model_region" + ] = _end_user_object.allowed_model_region if _end_user_object.litellm_budget_table is not None: budget_info = _end_user_object.litellm_budget_table if budget_info.tpm_limit is not None: - end_user_params["end_user_tpm_limit"] = ( - budget_info.tpm_limit - ) + end_user_params[ + "end_user_tpm_limit" + ] = budget_info.tpm_limit if budget_info.rpm_limit is not None: - end_user_params["end_user_rpm_limit"] = ( - budget_info.rpm_limit - ) + end_user_params[ + 
"end_user_rpm_limit" + ] = budget_info.rpm_limit if budget_info.max_budget is not None: - end_user_params["end_user_max_budget"] = ( - budget_info.max_budget - ) + end_user_params[ + "end_user_max_budget" + ] = budget_info.max_budget except Exception as e: if isinstance(e, litellm.BudgetExceededError): raise e @@ -801,7 +798,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 # Check 3. Check if user is in their team budget if valid_token.team_member_spend is not None: if prisma_client is not None: - _cache_key = f"{valid_token.team_id}_{valid_token.user_id}" team_member_info = await user_api_key_cache.async_get_cache( diff --git a/litellm/proxy/common_utils/encrypt_decrypt_utils.py b/litellm/proxy/common_utils/encrypt_decrypt_utils.py index ec9279a089b1..34527348670f 100644 --- a/litellm/proxy/common_utils/encrypt_decrypt_utils.py +++ b/litellm/proxy/common_utils/encrypt_decrypt_utils.py @@ -21,7 +21,6 @@ def _get_salt_key(): def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None): - signing_key = new_encryption_key or _get_salt_key() try: @@ -41,7 +40,6 @@ def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None): def decrypt_value_helper(value: str): - signing_key = _get_salt_key() try: diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index 5736ee215278..7220ccaa6514 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -142,10 +142,10 @@ def check_file_size_under_limit( if llm_router is not None and request_data["model"] in router_model_names: try: - deployment: Optional[Deployment] = ( - llm_router.get_deployment_by_model_group_name( - model_group_name=request_data["model"] - ) + deployment: Optional[ + Deployment + ] = llm_router.get_deployment_by_model_group_name( + model_group_name=request_data["model"] ) if ( deployment diff --git a/litellm/proxy/db/log_db_metrics.py b/litellm/proxy/db/log_db_metrics.py index 9bd335079324..5c7951553245 100644 --- a/litellm/proxy/db/log_db_metrics.py +++ b/litellm/proxy/db/log_db_metrics.py @@ -36,7 +36,6 @@ def log_db_metrics(func): @wraps(func) async def wrapper(*args, **kwargs): - start_time: datetime = datetime.now() try: diff --git a/litellm/proxy/db/redis_update_buffer.py b/litellm/proxy/db/redis_update_buffer.py index f77c839aafb8..f98fc9300f8c 100644 --- a/litellm/proxy/db/redis_update_buffer.py +++ b/litellm/proxy/db/redis_update_buffer.py @@ -43,9 +43,9 @@ def _should_commit_spend_updates_to_redis() -> bool: """ from litellm.proxy.proxy_server import general_settings - _use_redis_transaction_buffer: Optional[Union[bool, str]] = ( - general_settings.get("use_redis_transaction_buffer", False) - ) + _use_redis_transaction_buffer: Optional[ + Union[bool, str] + ] = general_settings.get("use_redis_transaction_buffer", False) if isinstance(_use_redis_transaction_buffer, str): _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer) if _use_redis_transaction_buffer is None: @@ -78,15 +78,13 @@ async def store_in_memory_spend_updates_in_redis( "redis_cache is None, skipping store_in_memory_spend_updates_in_redis" ) return - db_spend_update_transactions: DBSpendUpdateTransactions = ( - DBSpendUpdateTransactions( - user_list_transactions=prisma_client.user_list_transactions, - end_user_list_transactions=prisma_client.end_user_list_transactions, - key_list_transactions=prisma_client.key_list_transactions, - 
team_list_transactions=prisma_client.team_list_transactions, - team_member_list_transactions=prisma_client.team_member_list_transactions, - org_list_transactions=prisma_client.org_list_transactions, - ) + db_spend_update_transactions: DBSpendUpdateTransactions = DBSpendUpdateTransactions( + user_list_transactions=prisma_client.user_list_transactions, + end_user_list_transactions=prisma_client.end_user_list_transactions, + key_list_transactions=prisma_client.key_list_transactions, + team_list_transactions=prisma_client.team_list_transactions, + team_member_list_transactions=prisma_client.team_member_list_transactions, + org_list_transactions=prisma_client.org_list_transactions, ) # only store in redis if there are any updates to commit diff --git a/litellm/proxy/guardrails/guardrail_helpers.py b/litellm/proxy/guardrails/guardrail_helpers.py index c351f9f762a4..e97031146035 100644 --- a/litellm/proxy/guardrails/guardrail_helpers.py +++ b/litellm/proxy/guardrails/guardrail_helpers.py @@ -45,7 +45,6 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b # v1 implementation of this if isinstance(request_guardrails, dict): - # get guardrail configs from `init_guardrails.py` # for all requested guardrails -> get their associated callbacks for _guardrail_name, should_run in request_guardrails.items(): diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index 7686fba7cf6a..5c6b53be251c 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -192,7 +192,6 @@ def _prepare_request( async def make_bedrock_api_request( self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None ): - credentials, aws_region_name = self._load_credentials() bedrock_request_data: dict = dict( self.convert_to_bedrock_format( diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py index 5d3b8be33410..2dd8a3154a8a 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py @@ -148,9 +148,9 @@ async def _check( # noqa: PLR0915 text = "" _json_data: str = "" if "messages" in data and isinstance(data["messages"], list): - prompt_injection_obj: Optional[GuardrailItem] = ( - litellm.guardrail_name_config_map.get("prompt_injection") - ) + prompt_injection_obj: Optional[ + GuardrailItem + ] = litellm.guardrail_name_config_map.get("prompt_injection") if prompt_injection_obj is not None: enabled_roles = prompt_injection_obj.enabled_roles else: diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index 86d2c8b25add..0c7d2a1fe60f 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -95,9 +95,11 @@ def validate_environment( presidio_analyzer_api_base: Optional[str] = None, presidio_anonymizer_api_base: Optional[str] = None, ): - self.presidio_analyzer_api_base: Optional[str] = ( - presidio_analyzer_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None) # type: ignore - ) + self.presidio_analyzer_api_base: Optional[ + str + ] = presidio_analyzer_api_base or get_secret( + "PRESIDIO_ANALYZER_API_BASE", None + ) # type: ignore self.presidio_anonymizer_api_base: Optional[ str ] = presidio_anonymizer_api_base or litellm.get_secret( @@ -168,7 
+170,6 @@ async def check_pii( async with session.post( analyze_url, json=analyze_payload ) as response: - analyze_results = await response.json() # Make the second request to /anonymize @@ -228,7 +229,6 @@ async def async_pre_call_hook( """ try: - content_safety = data.get("content_safety", None) verbose_proxy_logger.debug("content_safety: %s", content_safety) presidio_config = self.get_presidio_settings_from_request_data(data) diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py index 15a9bc1ba81a..e06366d02b56 100644 --- a/litellm/proxy/hooks/dynamic_rate_limiter.py +++ b/litellm/proxy/hooks/dynamic_rate_limiter.py @@ -71,7 +71,6 @@ async def async_set_cache_sadd(self, model: str, value: List): class _PROXY_DynamicRateLimitHandler(CustomLogger): - # Class variables or attributes def __init__(self, internal_usage_cache: DualCache): self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache) @@ -121,12 +120,13 @@ async def check_available_usage( active_projects = await self.internal_usage_cache.async_get_cache( model=model ) - current_model_tpm, current_model_rpm = ( - await self.llm_router.get_model_group_usage(model_group=model) - ) - model_group_info: Optional[ModelGroupInfo] = ( - self.llm_router.get_model_group_info(model_group=model) - ) + ( + current_model_tpm, + current_model_rpm, + ) = await self.llm_router.get_model_group_usage(model_group=model) + model_group_info: Optional[ + ModelGroupInfo + ] = self.llm_router.get_model_group_info(model_group=model) total_model_tpm: Optional[int] = None total_model_rpm: Optional[int] = None if model_group_info is not None: @@ -210,10 +210,14 @@ async def async_pre_call_hook( key_priority: Optional[str] = user_api_key_dict.metadata.get( "priority", None ) - available_tpm, available_rpm, model_tpm, model_rpm, active_projects = ( - await self.check_available_usage( - model=data["model"], priority=key_priority - ) + ( + available_tpm, + available_rpm, + model_tpm, + model_rpm, + active_projects, + ) = await self.check_available_usage( + model=data["model"], priority=key_priority ) ### CHECK TPM ### if available_tpm is not None and available_tpm == 0: @@ -267,21 +271,25 @@ async def async_post_call_success_hook( key_priority: Optional[str] = user_api_key_dict.metadata.get( "priority", None ) - available_tpm, available_rpm, model_tpm, model_rpm, active_projects = ( - await self.check_available_usage( - model=model_info["model_name"], priority=key_priority - ) - ) - response._hidden_params["additional_headers"] = ( - { # Add additional response headers - easier debugging - "x-litellm-model_group": model_info["model_name"], - "x-ratelimit-remaining-litellm-project-tokens": available_tpm, - "x-ratelimit-remaining-litellm-project-requests": available_rpm, - "x-ratelimit-remaining-model-tokens": model_tpm, - "x-ratelimit-remaining-model-requests": model_rpm, - "x-ratelimit-current-active-projects": active_projects, - } + ( + available_tpm, + available_rpm, + model_tpm, + model_rpm, + active_projects, + ) = await self.check_available_usage( + model=model_info["model_name"], priority=key_priority ) + response._hidden_params[ + "additional_headers" + ] = { # Add additional response headers - easier debugging + "x-litellm-model_group": model_info["model_name"], + "x-ratelimit-remaining-litellm-project-tokens": available_tpm, + "x-ratelimit-remaining-litellm-project-requests": available_rpm, + "x-ratelimit-remaining-model-tokens": model_tpm, + "x-ratelimit-remaining-model-requests": 
model_rpm, + "x-ratelimit-current-active-projects": active_projects, + } return response return await super().async_post_call_success_hook( diff --git a/litellm/proxy/hooks/key_management_event_hooks.py b/litellm/proxy/hooks/key_management_event_hooks.py index 2030cb2a4521..c2c4f0669f3d 100644 --- a/litellm/proxy/hooks/key_management_event_hooks.py +++ b/litellm/proxy/hooks/key_management_event_hooks.py @@ -28,7 +28,6 @@ class KeyManagementEventHooks: - @staticmethod async def async_key_generated_hook( data: GenerateKeyRequest, diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 06f3b6afe512..83b3c7179aca 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -201,7 +201,9 @@ async def async_pre_call_hook( # noqa: PLR0915 if rpm_limit is None: rpm_limit = sys.maxsize - values_to_update_in_cache: List[Tuple[Any, Any]] = ( + values_to_update_in_cache: List[ + Tuple[Any, Any] + ] = ( [] ) # values that need to get updated in cache, will run a batch_set_cache after this function @@ -678,9 +680,9 @@ async def async_log_success_event( # noqa: PLR0915 async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): try: self.print_verbose("Inside Max Parallel Request Failure Hook") - litellm_parent_otel_span: Union[Span, None] = ( - _get_parent_otel_span_from_kwargs(kwargs=kwargs) - ) + litellm_parent_otel_span: Union[ + Span, None + ] = _get_parent_otel_span_from_kwargs(kwargs=kwargs) _metadata = kwargs["litellm_params"].get("metadata", {}) or {} global_max_parallel_requests = _metadata.get( "global_max_parallel_requests", None @@ -807,11 +809,11 @@ async def async_post_call_success_hook( current_minute = datetime.now().strftime("%M") precise_minute = f"{current_date}-{current_hour}-{current_minute}" request_count_api_key = f"{api_key}::{precise_minute}::request_count" - current: Optional[CurrentItemRateLimit] = ( - await self.internal_usage_cache.async_get_cache( - key=request_count_api_key, - litellm_parent_otel_span=user_api_key_dict.parent_otel_span, - ) + current: Optional[ + CurrentItemRateLimit + ] = await self.internal_usage_cache.async_get_cache( + key=request_count_api_key, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) key_remaining_rpm_limit: Optional[int] = None @@ -843,15 +845,15 @@ async def async_post_call_success_hook( _additional_headers = _hidden_params.get("additional_headers", {}) or {} if key_remaining_rpm_limit is not None: - _additional_headers["x-ratelimit-remaining-requests"] = ( - key_remaining_rpm_limit - ) + _additional_headers[ + "x-ratelimit-remaining-requests" + ] = key_remaining_rpm_limit if key_rpm_limit is not None: _additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit if key_remaining_tpm_limit is not None: - _additional_headers["x-ratelimit-remaining-tokens"] = ( - key_remaining_tpm_limit - ) + _additional_headers[ + "x-ratelimit-remaining-tokens" + ] = key_remaining_tpm_limit if key_tpm_limit is not None: _additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py index b1b2bbee5c43..b8fa8466a32f 100644 --- a/litellm/proxy/hooks/prompt_injection_detection.py +++ b/litellm/proxy/hooks/prompt_injection_detection.py @@ -196,7 +196,6 @@ async def async_pre_call_hook( return data except HTTPException as e: - if ( e.status_code == 400 and isinstance(e.detail, dict) diff 
--git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index f205b0146fe1..39c1eeace983 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -51,10 +51,10 @@ async def async_post_call_failure_hook( ) _metadata["user_api_key"] = user_api_key_dict.api_key _metadata["status"] = "failure" - _metadata["error_information"] = ( - StandardLoggingPayloadSetup.get_error_information( - original_exception=original_exception, - ) + _metadata[ + "error_information" + ] = StandardLoggingPayloadSetup.get_error_information( + original_exception=original_exception, ) existing_metadata: dict = request_data.get("metadata", None) or {} diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index ece5ecf4b7cc..6427be5a6e25 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -346,11 +346,11 @@ def add_key_level_controls( ## KEY-LEVEL SPEND LOGS / TAGS if "tags" in key_metadata and key_metadata["tags"] is not None: - data[_metadata_variable_name]["tags"] = ( - LiteLLMProxyRequestSetup._merge_tags( - request_tags=data[_metadata_variable_name].get("tags"), - tags_to_add=key_metadata["tags"], - ) + data[_metadata_variable_name][ + "tags" + ] = LiteLLMProxyRequestSetup._merge_tags( + request_tags=data[_metadata_variable_name].get("tags"), + tags_to_add=key_metadata["tags"], ) if "spend_logs_metadata" in key_metadata and isinstance( key_metadata["spend_logs_metadata"], dict @@ -556,9 +556,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915 data[_metadata_variable_name]["litellm_api_version"] = version if general_settings is not None: - data[_metadata_variable_name]["global_max_parallel_requests"] = ( - general_settings.get("global_max_parallel_requests", None) - ) + data[_metadata_variable_name][ + "global_max_parallel_requests" + ] = general_settings.get("global_max_parallel_requests", None) ### KEY-LEVEL Controls key_metadata = user_api_key_dict.metadata diff --git a/litellm/proxy/management_endpoints/budget_management_endpoints.py b/litellm/proxy/management_endpoints/budget_management_endpoints.py index 20aa1c6bbf00..65b0156afe49 100644 --- a/litellm/proxy/management_endpoints/budget_management_endpoints.py +++ b/litellm/proxy/management_endpoints/budget_management_endpoints.py @@ -197,7 +197,6 @@ async def budget_settings( for field_name, field_info in BudgetNewRequest.model_fields.items(): if field_name in allowed_args: - _stored_in_db = True _response_obj = ConfigList( diff --git a/litellm/proxy/management_endpoints/common_utils.py b/litellm/proxy/management_endpoints/common_utils.py index d80a06c597c2..550ff44616cf 100644 --- a/litellm/proxy/management_endpoints/common_utils.py +++ b/litellm/proxy/management_endpoints/common_utils.py @@ -16,7 +16,6 @@ def _is_user_team_admin( if ( member.user_id is not None and member.user_id == user_api_key_dict.user_id ) and member.role == "admin": - return True return False diff --git a/litellm/proxy/management_endpoints/customer_endpoints.py b/litellm/proxy/management_endpoints/customer_endpoints.py index 976ff8581f48..1f6f846bc774 100644 --- a/litellm/proxy/management_endpoints/customer_endpoints.py +++ b/litellm/proxy/management_endpoints/customer_endpoints.py @@ -230,7 +230,6 @@ async def new_end_user( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) try: - ## VALIDATION ## if data.default_model is not None: if llm_router is None: diff 
--git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 79de6da1fd11..90444013a8f9 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -82,9 +82,9 @@ def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> d data_json["user_id"] = str(uuid.uuid4()) auto_create_key = data_json.pop("auto_create_key", True) if auto_create_key is False: - data_json["table_name"] = ( - "user" # only create a user, don't create key if 'auto_create_key' set to False - ) + data_json[ + "table_name" + ] = "user" # only create a user, don't create key if 'auto_create_key' set to False is_internal_user = False if data.user_role and data.user_role.is_internal_user_role: @@ -370,7 +370,6 @@ async def ui_get_available_role( _data_to_return = {} for role in LitellmUserRoles: - # We only show a subset of roles on UI if role in [ LitellmUserRoles.PROXY_ADMIN, @@ -652,9 +651,9 @@ def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> di "budget_duration" not in non_default_values ): # applies internal user limits, if user role updated if is_internal_user and litellm.internal_user_budget_duration is not None: - non_default_values["budget_duration"] = ( - litellm.internal_user_budget_duration - ) + non_default_values[ + "budget_duration" + ] = litellm.internal_user_budget_duration duration_s = duration_in_seconds( duration=non_default_values["budget_duration"] ) @@ -965,13 +964,13 @@ async def get_users( "in": user_id_list, # Now passing a list of strings as required by Prisma } - users: Optional[List[LiteLLM_UserTable]] = ( - await prisma_client.db.litellm_usertable.find_many( - where=where_conditions, - skip=skip, - take=page_size, - order={"created_at": "desc"}, - ) + users: Optional[ + List[LiteLLM_UserTable] + ] = await prisma_client.db.litellm_usertable.find_many( + where=where_conditions, + skip=skip, + take=page_size, + order={"created_at": "desc"}, ) # Get total count of user rows @@ -1226,13 +1225,13 @@ async def ui_view_users( } # Query users with pagination and filters - users: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_usertable.find_many( - where=where_conditions, - skip=skip, - take=page_size, - order={"created_at": "desc"}, - ) + users: Optional[ + List[BaseModel] + ] = await prisma_client.db.litellm_usertable.find_many( + where=where_conditions, + skip=skip, + take=page_size, + order={"created_at": "desc"}, ) if not users: diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 9141d9d14a34..b0bf1fb6194e 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -227,7 +227,6 @@ def _personal_key_membership_check( def _personal_key_generation_check( user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest ): - if ( litellm.key_generation_settings is None or litellm.key_generation_settings.get("personal_key_generation") is None @@ -568,9 +567,9 @@ async def generate_key_fn( # noqa: PLR0915 request_type="key", **data_json, table_name="key" ) - response["soft_budget"] = ( - data.soft_budget - ) # include the user-input soft budget in the response + response[ + "soft_budget" + ] = data.soft_budget # include the user-input soft budget in the response response = GenerateKeyResponse(**response) @@ 
-1448,10 +1447,10 @@ async def delete_verification_tokens( try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": tokens}} - ) + _keys_being_deleted: List[ + LiteLLM_VerificationToken + ] = await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} ) # Assuming 'db' is your Prisma Client instance @@ -1553,9 +1552,9 @@ async def _rotate_master_key( from litellm.proxy.proxy_server import proxy_config try: - models: Optional[List] = ( - await prisma_client.db.litellm_proxymodeltable.find_many() - ) + models: Optional[ + List + ] = await prisma_client.db.litellm_proxymodeltable.find_many() except Exception: models = None # 2. process model table @@ -1677,7 +1676,6 @@ async def regenerate_key_fn( Note: This is an Enterprise feature. It requires a premium license to use. """ try: - from litellm.proxy.proxy_server import ( hash_token, master_key, @@ -1824,7 +1822,6 @@ async def validate_key_list_check( key_alias: Optional[str], prisma_client: PrismaClient, ) -> Optional[LiteLLM_UserTable]: - if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value: return None @@ -1835,11 +1832,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[BaseModel] = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, - ) + complete_user_info_db_obj: Optional[ + BaseModel + ] = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, ) if complete_user_info_db_obj is None: @@ -1900,10 +1897,10 @@ async def get_admin_team_ids( if complete_user_info is None: return [] # Get all teams that user is an admin of - teams: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} - ) + teams: Optional[ + List[BaseModel] + ] = await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": complete_user_info.teams}} ) if teams is None: return [] diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 88245e36d18e..0e8a9e4cc87b 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -399,7 +399,6 @@ async def can_user_make_model_call( prisma_client: PrismaClient, premium_user: bool, ) -> Literal[True]: - ## Check team model auth if ( model_params.model_info is not None @@ -579,7 +578,6 @@ async def add_new_model( ) try: - if prisma_client is None: raise HTTPException( status_code=500, @@ -717,7 +715,6 @@ async def update_model( ) try: - if prisma_client is None: raise HTTPException( status_code=500, diff --git a/litellm/proxy/management_endpoints/organization_endpoints.py b/litellm/proxy/management_endpoints/organization_endpoints.py index c202043fbe0c..37de12a9d291 100644 --- a/litellm/proxy/management_endpoints/organization_endpoints.py +++ b/litellm/proxy/management_endpoints/organization_endpoints.py @@ -358,11 +358,11 @@ async def info_organization(organization_id: str): if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No db 
connected"}) - response: Optional[LiteLLM_OrganizationTableWithMembers] = ( - await prisma_client.db.litellm_organizationtable.find_unique( - where={"organization_id": organization_id}, - include={"litellm_budget_table": True, "members": True, "teams": True}, - ) + response: Optional[ + LiteLLM_OrganizationTableWithMembers + ] = await prisma_client.db.litellm_organizationtable.find_unique( + where={"organization_id": organization_id}, + include={"litellm_budget_table": True, "members": True, "teams": True}, ) if response is None: @@ -486,12 +486,13 @@ async def organization_member_add( updated_organization_memberships: List[LiteLLM_OrganizationMembershipTable] = [] for member in members: - updated_user, updated_organization_membership = ( - await add_member_to_organization( - member=member, - organization_id=data.organization_id, - prisma_client=prisma_client, - ) + ( + updated_user, + updated_organization_membership, + ) = await add_member_to_organization( + member=member, + organization_id=data.organization_id, + prisma_client=prisma_client, ) updated_users.append(updated_user) @@ -657,16 +658,16 @@ async def organization_member_update( }, data={"budget_id": budget_id}, ) - final_organization_membership: Optional[BaseModel] = ( - await prisma_client.db.litellm_organizationmembership.find_unique( - where={ - "user_id_organization_id": { - "user_id": data.user_id, - "organization_id": data.organization_id, - } - }, - include={"litellm_budget_table": True}, - ) + final_organization_membership: Optional[ + BaseModel + ] = await prisma_client.db.litellm_organizationmembership.find_unique( + where={ + "user_id_organization_id": { + "user_id": data.user_id, + "organization_id": data.organization_id, + } + }, + include={"litellm_budget_table": True}, ) if final_organization_membership is None: diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index f5bcc6ba11c7..842b5c8e75f3 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -506,12 +506,12 @@ async def update_team( updated_kv["model_id"] = _model_id updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv) - team_row: Optional[LiteLLM_TeamTable] = ( - await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, - data=updated_kv, - include={"litellm_model_table": True}, # type: ignore - ) + team_row: Optional[ + LiteLLM_TeamTable + ] = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, + data=updated_kv, + include={"litellm_model_table": True}, # type: ignore ) if team_row is None or team_row.team_id is None: @@ -1137,10 +1137,10 @@ async def delete_team( team_rows: List[LiteLLM_TeamTable] = [] for team_id in data.team_ids: try: - team_row_base: Optional[BaseModel] = ( - await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id} - ) + team_row_base: Optional[ + BaseModel + ] = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} ) if team_row_base is None: raise Exception @@ -1298,10 +1298,10 @@ async def team_info( ) try: - team_info: Optional[BaseModel] = ( - await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id} - ) + team_info: Optional[ + BaseModel + ] = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} ) if team_info is None: raise Exception diff --git a/litellm/proxy/management_endpoints/ui_sso.py 
b/litellm/proxy/management_endpoints/ui_sso.py index 86dec9fcafef..d38ff6b536d1 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -213,9 +213,9 @@ async def google_login(request: Request): # noqa: PLR0915 if state: redirect_params["state"] = state elif "okta" in generic_authorization_endpoint: - redirect_params["state"] = ( - uuid.uuid4().hex - ) # set state param for okta - required + redirect_params[ + "state" + ] = uuid.uuid4().hex # set state param for okta - required return await generic_sso.get_login_redirect(**redirect_params) # type: ignore elif ui_username is not None: # No Google, Microsoft SSO @@ -725,9 +725,9 @@ async def insert_sso_user( if user_defined_values.get("max_budget") is None: user_defined_values["max_budget"] = litellm.max_internal_user_budget if user_defined_values.get("budget_duration") is None: - user_defined_values["budget_duration"] = ( - litellm.internal_user_budget_duration - ) + user_defined_values[ + "budget_duration" + ] = litellm.internal_user_budget_duration if user_defined_values["user_role"] is None: user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 69a5cf914198..cb8e079b76d7 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -179,7 +179,7 @@ def _delete_api_key_from_cache(kwargs): user_api_key_cache.delete_cache(key=update_request.key) # delete key request - if isinstance(update_request, KeyRequest): + if isinstance(update_request, KeyRequest) and update_request.keys: for key in update_request.keys: user_api_key_cache.delete_cache(key=key) pass @@ -251,7 +251,6 @@ async def send_management_endpoint_alert( proxy_logging_obj is not None and proxy_logging_obj.slack_alerting_instance is not None ): - # Virtual Key Events if function_name in management_function_to_event_name: _event_name: AlertType = management_function_to_event_name[function_name] diff --git a/litellm/proxy/openai_files_endpoints/files_endpoints.py b/litellm/proxy/openai_files_endpoints/files_endpoints.py index ffbca91c69d6..e810ba026e99 100644 --- a/litellm/proxy/openai_files_endpoints/files_endpoints.py +++ b/litellm/proxy/openai_files_endpoints/files_endpoints.py @@ -316,7 +316,6 @@ async def get_file_content( data: Dict = {} try: - # Include original request and headers in the data data = await add_litellm_data_to_request( data=data, diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index 51845956fcdb..d6f2a01712be 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -25,7 +25,6 @@ class AnthropicPassthroughLoggingHandler: - @staticmethod def anthropic_passthrough_handler( httpx_response: httpx.Response, @@ -123,9 +122,9 @@ def _create_anthropic_response_logging_payload( litellm_model_response.id = logging_obj.litellm_call_id litellm_model_response.model = model logging_obj.model_call_details["model"] = model - logging_obj.model_call_details["custom_llm_provider"] = ( - litellm.LlmProviders.ANTHROPIC.value - ) + logging_obj.model_call_details[ + "custom_llm_provider" + ] = litellm.LlmProviders.ANTHROPIC.value 
return kwargs except Exception as e: verbose_proxy_logger.exception( diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index 9443563738af..a20f39e65ce9 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from urllib.parse import urlparse + import httpx import litellm @@ -222,7 +223,9 @@ def extract_model_from_url(url: str) -> str: @staticmethod def _get_custom_llm_provider_from_url(url: str) -> str: parsed_url = urlparse(url) - if parsed_url.hostname and parsed_url.hostname.endswith("generativelanguage.googleapis.com"): + if parsed_url.hostname and parsed_url.hostname.endswith( + "generativelanguage.googleapis.com" + ): return litellm.LlmProviders.GEMINI.value return litellm.LlmProviders.VERTEX_AI.value diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index a13b0dc216e0..a6b1b3e61499 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -373,7 +373,6 @@ async def pass_through_request( # noqa: PLR0915 litellm_call_id = str(uuid.uuid4()) url: Optional[httpx.URL] = None try: - from litellm.litellm_core_utils.litellm_logging import Logging from litellm.proxy.proxy_server import proxy_logging_obj @@ -384,7 +383,6 @@ async def pass_through_request( # noqa: PLR0915 ) if merge_query_params: - # Create a new URL with the merged query params url = url.copy_with( query=urlencode( @@ -771,7 +769,6 @@ def _is_streaming_response(response: httpx.Response) -> bool: async def initialize_pass_through_endpoints(pass_through_endpoints: list): - verbose_proxy_logger.debug("initializing pass through endpoints") from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes from litellm.proxy.proxy_server import app, premium_user diff --git a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py index 89cccfc07198..a02cacc3cc0d 100644 --- a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py +++ b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py @@ -130,9 +130,9 @@ def add_vertex_credentials( vertex_location=location, vertex_credentials=vertex_credentials, ) - self.deployment_key_to_vertex_credentials[deployment_key] = ( - vertex_pass_through_credentials - ) + self.deployment_key_to_vertex_credentials[ + deployment_key + ] = vertex_pass_through_credentials def _get_deployment_key( self, project_id: Optional[str], location: Optional[str] diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index b022bf1d25b9..2c11e4a2dd5a 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -21,7 +21,6 @@ class PassThroughStreamingHandler: - @staticmethod async def chunk_processor( response: httpx.Response, diff --git a/litellm/proxy/prisma_migration.py b/litellm/proxy/prisma_migration.py index 22fa4da9deca..c1e2220c1591 100644 --- 
a/litellm/proxy/prisma_migration.py +++ b/litellm/proxy/prisma_migration.py @@ -9,8 +9,8 @@ sys.path.insert( 0, os.path.abspath("./") ) # Adds the parent directory to the system path -from litellm.secret_managers.aws_secret_manager import decrypt_env_var from litellm._logging import verbose_proxy_logger +from litellm.secret_managers.aws_secret_manager import decrypt_env_var if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True": ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV @@ -39,7 +39,9 @@ ) exit(1) else: - verbose_proxy_logger.info("Using existing DATABASE_URL environment variable") # Log existing DATABASE_URL + verbose_proxy_logger.info( + "Using existing DATABASE_URL environment variable" + ) # Log existing DATABASE_URL # Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations direct_url = os.getenv("DIRECT_URL") @@ -63,12 +65,18 @@ # run prisma generate verbose_proxy_logger.info("Running 'prisma generate'...") result = subprocess.run(["prisma", "generate"], capture_output=True, text=True) - verbose_proxy_logger.info(f"'prisma generate' stdout: {result.stdout}") # Log stdout + verbose_proxy_logger.info( + f"'prisma generate' stdout: {result.stdout}" + ) # Log stdout exit_code = result.returncode if exit_code != 0: - verbose_proxy_logger.info(f"'prisma generate' failed with exit code {exit_code}.") - verbose_proxy_logger.error(f"'prisma generate' stderr: {result.stderr}") # Log stderr + verbose_proxy_logger.info( + f"'prisma generate' failed with exit code {exit_code}." + ) + verbose_proxy_logger.error( + f"'prisma generate' stderr: {result.stderr}" + ) # Log stderr # Run the Prisma db push command verbose_proxy_logger.info("Running 'prisma db push --accept-data-loss'...") @@ -79,14 +87,20 @@ exit_code = result.returncode if exit_code != 0: - verbose_proxy_logger.info(f"'prisma db push' stderr: {result.stderr}") # Log stderr - verbose_proxy_logger.error(f"'prisma db push' failed with exit code {exit_code}.") + verbose_proxy_logger.info( + f"'prisma db push' stderr: {result.stderr}" + ) # Log stderr + verbose_proxy_logger.error( + f"'prisma db push' failed with exit code {exit_code}." + ) if retry_count < max_retries: verbose_proxy_logger.info("Retrying in 10 seconds...") time.sleep(10) if retry_count == max_retries and exit_code != 0: - verbose_proxy_logger.error(f"Unable to push database changes after {max_retries} retries.") + verbose_proxy_logger.error( + f"Unable to push database changes after {max_retries} retries." 
+ ) exit(1) verbose_proxy_logger.info("Database push successful!") diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 89c15f413db8..026e9a42cece 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -763,9 +763,9 @@ async def redirect_ui_middleware(request: Request, call_next): dual_cache=user_api_key_cache ) litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter) -redis_usage_cache: Optional[RedisCache] = ( - None # redis cache used for tracking spend, tpm/rpm limits -) +redis_usage_cache: Optional[ + RedisCache +] = None # redis cache used for tracking spend, tpm/rpm limits user_custom_auth = None user_custom_key_generate = None user_custom_sso = None @@ -818,7 +818,6 @@ async def check_request_disconnection(request: Request, llm_api_call_task): while time.time() - start_time < 600: await asyncio.sleep(1) if await request.is_disconnected(): - # cancel the LLM API Call task if any passed - this is passed from individual providers # Example OpenAI, Azure, VertexAI etc llm_api_call_task.cancel() @@ -1092,9 +1091,9 @@ async def _update_team_cache(): _id = "team_id:{}".format(team_id) try: # Fetch the existing cost for the given user - existing_spend_obj: Optional[LiteLLM_TeamTable] = ( - await user_api_key_cache.async_get_cache(key=_id) - ) + existing_spend_obj: Optional[ + LiteLLM_TeamTable + ] = await user_api_key_cache.async_get_cache(key=_id) if existing_spend_obj is None: # do nothing if team not in api key cache return @@ -1610,7 +1609,6 @@ async def load_config( # noqa: PLR0915 litellm.guardrail_name_config_map = guardrail_name_config_map elif key == "callbacks": - initialize_callbacks_on_proxy( value=value, premium_user=premium_user, @@ -2765,9 +2763,9 @@ async def initialize( # noqa: PLR0915 user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base if api_version: - os.environ["AZURE_API_VERSION"] = ( - api_version # set this for azure - litellm can read this from the env - ) + os.environ[ + "AZURE_API_VERSION" + ] = api_version # set this for azure - litellm can read this from the env if max_tokens: # model-specific param dynamic_config[user_model]["max_tokens"] = max_tokens if temperature: # model-specific param @@ -2810,7 +2808,6 @@ async def async_assistants_data_generator( try: time.time() async with response as chunk: - ### CALL HOOKS ### - modify outgoing data chunk = await proxy_logging_obj.async_post_call_streaming_hook( user_api_key_dict=user_api_key_dict, response=chunk @@ -4675,7 +4672,6 @@ async def get_thread( global proxy_logging_obj data: Dict = {} try: - # Include original request and headers in the data data = await add_litellm_data_to_request( data=data, @@ -6385,7 +6381,6 @@ async def alerting_settings( for field_name, field_info in SlackAlertingArgs.model_fields.items(): if field_name in allowed_args: - _stored_in_db: Optional[bool] = None if field_name in alerting_args_dict: _stored_in_db = True @@ -7333,7 +7328,6 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915 "success_callback" in updated_litellm_settings and "success_callback" in config["litellm_settings"] ): - # check both success callback are lists if isinstance( config["litellm_settings"]["success_callback"], list @@ -7588,7 +7582,6 @@ async def get_config_list( for field_name, field_info in ConfigGeneralSettings.model_fields.items(): if field_name in allowed_args: - ## HANDLE TYPED DICT typed_dict_type = allowed_args[field_name]["type"] @@ -7621,9 +7614,9 @@ async def get_config_list( 
hasattr(sub_field_info, "description") and sub_field_info.description is not None ): - nested_fields[idx].field_description = ( - sub_field_info.description - ) + nested_fields[ + idx + ].field_description = sub_field_info.description idx += 1 _stored_in_db = None diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index 4c0e22aef790..4690b6cbd876 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -286,7 +286,6 @@ async def get_global_activity( user_api_key_dict, start_date_obj, end_date_obj ) else: - sql_query = """ SELECT date_trunc('day', "startTime") AS date, @@ -453,7 +452,6 @@ async def get_global_activity_model( user_api_key_dict, start_date_obj, end_date_obj ) else: - sql_query = """ SELECT model_group, @@ -1096,7 +1094,6 @@ async def get_global_spend_report( start_date_obj, end_date_obj, team_id, customer_id, prisma_client ) if group_by == "team": - # first get data from spend logs -> SpendByModelApiKey # then read data from "SpendByModelApiKey" to format the response obj sql_query = """ @@ -1689,7 +1686,6 @@ async def ui_view_spend_logs( # noqa: PLR0915 ) try: - # Convert the date strings to datetime objects start_date_obj = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S").replace( tzinfo=timezone.utc @@ -2160,7 +2156,6 @@ async def global_spend_for_internal_user( code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) try: - user_id = user_api_key_dict.user_id if user_id is None: raise ValueError("/global/spend/logs Error: User ID is None") @@ -2293,7 +2288,6 @@ async def global_spend(): from litellm.proxy.proxy_server import prisma_client try: - total_spend = 0.0 if prisma_client is None: diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 4c82f586fb0c..7c73c45fb1b8 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -491,7 +491,6 @@ async def pre_call_hook( try: for callback in litellm.callbacks: - _callback = None if isinstance(callback, str): _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( @@ -1197,9 +1196,9 @@ def add_spend_log_transaction_to_daily_user_transaction( api_requests=1, ) - self.daily_user_spend_transactions[daily_transaction_key] = ( - daily_transaction - ) + self.daily_user_spend_transactions[ + daily_transaction_key + ] = daily_transaction except Exception as e: raise e diff --git a/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py b/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py index 0c91c326f555..684e2ad0617e 100644 --- a/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py @@ -69,10 +69,10 @@ async def langfuse_proxy_route( request=request, api_key="Bearer {}".format(api_key) ) - callback_settings_obj: Optional[TeamCallbackMetadata] = ( - _get_dynamic_logging_metadata( - user_api_key_dict=user_api_key_dict, proxy_config=proxy_config - ) + callback_settings_obj: Optional[ + TeamCallbackMetadata + ] = _get_dynamic_logging_metadata( + user_api_key_dict=user_api_key_dict, proxy_config=proxy_config ) dynamic_langfuse_public_key: Optional[str] = None diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index ce8ae21c82bf..9307ce5a5500 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -107,13 +107,16 @@ def rerank( # noqa: PLR0915 k for k, v in unique_version_params.items() if v is not None ] - model, 
_custom_llm_provider, dynamic_api_key, dynamic_api_base = ( - litellm.get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=optional_params.api_base, - api_key=optional_params.api_key, - ) + ( + model, + _custom_llm_provider, + dynamic_api_key, + dynamic_api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=optional_params.api_base, + api_key=optional_params.api_key, ) rerank_provider_config: BaseRerankConfig = ( @@ -272,7 +275,6 @@ def rerank( # noqa: PLR0915 _is_async=_is_async, ) elif _custom_llm_provider == "jina_ai": - if dynamic_api_key is None: raise ValueError( "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment" diff --git a/litellm/responses/main.py b/litellm/responses/main.py index aec2f8fe4a61..70b651f37661 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -165,21 +165,24 @@ def responses( # get llm provider logic litellm_params = GenericLiteLLMParams(**kwargs) - model, custom_llm_provider, dynamic_api_key, dynamic_api_base = ( - litellm.get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=litellm_params.api_base, - api_key=litellm_params.api_key, - ) + ( + model, + custom_llm_provider, + dynamic_api_key, + dynamic_api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=litellm_params.api_base, + api_key=litellm_params.api_key, ) # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=model, - provider=litellm.LlmProviders(custom_llm_provider), - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=model, + provider=litellm.LlmProviders(custom_llm_provider), ) if responses_api_provider_config is None: diff --git a/litellm/router.py b/litellm/router.py index f739bc381dbe..d0755f9a0d36 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -333,9 +333,9 @@ def __init__( # noqa: PLR0915 ) # names of models under litellm_params. ex. azure/chatgpt-v-2 self.deployment_latency_map = {} ### CACHING ### - cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = ( - "local" # default to an in-memory cache - ) + cache_type: Literal[ + "local", "redis", "redis-semantic", "s3", "disk" + ] = "local" # default to an in-memory cache redis_cache = None cache_config: Dict[str, Any] = {} @@ -556,9 +556,9 @@ def __init__( # noqa: PLR0915 ) ) - self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = ( - model_group_retry_policy - ) + self.model_group_retry_policy: Optional[ + Dict[str, RetryPolicy] + ] = model_group_retry_policy self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None if allowed_fails_policy is not None: @@ -1093,9 +1093,9 @@ def _update_kwargs_with_default_litellm_params( """ Adds default litellm params to kwargs, if set. """ - self.default_litellm_params[metadata_variable_name] = ( - self.default_litellm_params.pop("metadata", {}) - ) + self.default_litellm_params[ + metadata_variable_name + ] = self.default_litellm_params.pop("metadata", {}) for k, v in self.default_litellm_params.items(): if ( k not in kwargs and v is not None @@ -1678,14 +1678,16 @@ async def _prompt_management_factory( f"Prompt variables is set but not a dictionary. 
Got={prompt_variables}, type={type(prompt_variables)}" ) - model, messages, optional_params = ( - litellm_logging_object.get_chat_completion_prompt( - model=litellm_model, - messages=messages, - non_default_params=get_non_default_completion_params(kwargs=kwargs), - prompt_id=prompt_id, - prompt_variables=prompt_variables, - ) + ( + model, + messages, + optional_params, + ) = litellm_logging_object.get_chat_completion_prompt( + model=litellm_model, + messages=messages, + non_default_params=get_non_default_completion_params(kwargs=kwargs), + prompt_id=prompt_id, + prompt_variables=prompt_variables, ) kwargs = {**kwargs, **optional_params} @@ -2924,7 +2926,6 @@ async def aretrieve_batch( Future Improvement - cache the result. """ try: - filtered_model_list = self.get_model_list() if filtered_model_list is None: raise Exception("Router not yet initialized.") @@ -3211,11 +3212,11 @@ async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 if isinstance(e, litellm.ContextWindowExceededError): if context_window_fallbacks is not None: - fallback_model_group: Optional[List[str]] = ( - self._get_fallback_model_group_from_fallbacks( - fallbacks=context_window_fallbacks, - model_group=model_group, - ) + fallback_model_group: Optional[ + List[str] + ] = self._get_fallback_model_group_from_fallbacks( + fallbacks=context_window_fallbacks, + model_group=model_group, ) if fallback_model_group is None: raise original_exception @@ -3247,11 +3248,11 @@ async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 e.message += "\n{}".format(error_message) elif isinstance(e, litellm.ContentPolicyViolationError): if content_policy_fallbacks is not None: - fallback_model_group: Optional[List[str]] = ( - self._get_fallback_model_group_from_fallbacks( - fallbacks=content_policy_fallbacks, - model_group=model_group, - ) + fallback_model_group: Optional[ + List[str] + ] = self._get_fallback_model_group_from_fallbacks( + fallbacks=content_policy_fallbacks, + model_group=model_group, ) if fallback_model_group is None: raise original_exception @@ -3282,11 +3283,12 @@ async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 e.message += "\n{}".format(error_message) if fallbacks is not None and model_group is not None: verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") - fallback_model_group, generic_fallback_idx = ( - get_fallback_model_group( - fallbacks=fallbacks, # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}] - model_group=cast(str, model_group), - ) + ( + fallback_model_group, + generic_fallback_idx, + ) = get_fallback_model_group( + fallbacks=fallbacks, # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}] + model_group=cast(str, model_group), ) ## if none, check for generic fallback if ( @@ -3444,11 +3446,12 @@ async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915 """ Retry Logic """ - _healthy_deployments, _all_deployments = ( - await self._async_get_healthy_deployments( - model=kwargs.get("model") or "", - parent_otel_span=parent_otel_span, - ) + ( + _healthy_deployments, + _all_deployments, + ) = await self._async_get_healthy_deployments( + model=kwargs.get("model") or "", + parent_otel_span=parent_otel_span, ) # raises an exception if this error should not be retries @@ -3513,11 +3516,12 @@ async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915 remaining_retries = num_retries - current_attempt _model: Optional[str] = kwargs.get("model") # type: ignore if _model is not 
None: - _healthy_deployments, _ = ( - await self._async_get_healthy_deployments( - model=_model, - parent_otel_span=parent_otel_span, - ) + ( + _healthy_deployments, + _, + ) = await self._async_get_healthy_deployments( + model=_model, + parent_otel_span=parent_otel_span, ) else: _healthy_deployments = [] @@ -3884,7 +3888,6 @@ def deployment_callback_on_failure( ) if exception_headers is not None: - _time_to_cooldown = ( litellm.utils._get_retry_after_from_exception_header( response_headers=exception_headers @@ -6131,7 +6134,6 @@ def _track_deployment_metrics( try: model_id = deployment.get("model_info", {}).get("id", None) if response is None: - # update self.deployment_stats if model_id is not None: self._update_usage( diff --git a/litellm/router_strategy/base_routing_strategy.py b/litellm/router_strategy/base_routing_strategy.py index a39d17e38627..ea87e25eba59 100644 --- a/litellm/router_strategy/base_routing_strategy.py +++ b/litellm/router_strategy/base_routing_strategy.py @@ -38,9 +38,9 @@ def __init__( except RuntimeError: # No event loop in current thread self._create_sync_thread(default_sync_interval) - self.in_memory_keys_to_update: set[str] = ( - set() - ) # Set with max size of 1000 keys + self.in_memory_keys_to_update: set[ + str + ] = set() # Set with max size of 1000 keys async def _increment_value_in_current_window( self, key: str, value: Union[int, float], ttl: int diff --git a/litellm/router_strategy/budget_limiter.py b/litellm/router_strategy/budget_limiter.py index 4f123df28250..9e4001b67b9d 100644 --- a/litellm/router_strategy/budget_limiter.py +++ b/litellm/router_strategy/budget_limiter.py @@ -53,9 +53,9 @@ def __init__( self.dual_cache = dual_cache self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = [] asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis()) - self.provider_budget_config: Optional[GenericBudgetConfigType] = ( - provider_budget_config - ) + self.provider_budget_config: Optional[ + GenericBudgetConfigType + ] = provider_budget_config self.deployment_budget_config: Optional[GenericBudgetConfigType] = None self.tag_budget_config: Optional[GenericBudgetConfigType] = None self._init_provider_budgets() @@ -94,11 +94,13 @@ async def async_filter_deployments( potential_deployments: List[Dict] = [] - cache_keys, provider_configs, deployment_configs = ( - await self._async_get_cache_keys_for_router_budget_limiting( - healthy_deployments=healthy_deployments, - request_kwargs=request_kwargs, - ) + ( + cache_keys, + provider_configs, + deployment_configs, + ) = await self._async_get_cache_keys_for_router_budget_limiting( + healthy_deployments=healthy_deployments, + request_kwargs=request_kwargs, ) # Single cache read for all spend values @@ -114,17 +116,18 @@ async def async_filter_deployments( for idx, key in enumerate(cache_keys): spend_map[key] = float(current_spends[idx] or 0.0) - potential_deployments, deployment_above_budget_info = ( - self._filter_out_deployments_above_budget( - healthy_deployments=healthy_deployments, - provider_configs=provider_configs, - deployment_configs=deployment_configs, - spend_map=spend_map, - potential_deployments=potential_deployments, - request_tags=_get_tags_from_request_kwargs( - request_kwargs=request_kwargs - ), - ) + ( + potential_deployments, + deployment_above_budget_info, + ) = self._filter_out_deployments_above_budget( + healthy_deployments=healthy_deployments, + provider_configs=provider_configs, + deployment_configs=deployment_configs, + spend_map=spend_map, + 
potential_deployments=potential_deployments, + request_tags=_get_tags_from_request_kwargs( + request_kwargs=request_kwargs + ), ) if len(potential_deployments) == 0: diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index d1a46b7ea89e..9b1836586009 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -69,7 +69,6 @@ def pre_call_check(self, deployment: Dict) -> Optional[Dict]: Raises - RateLimitError if deployment over defined RPM limit """ try: - # ------------ # Setup values # ------------ diff --git a/litellm/router_utils/cooldown_callbacks.py b/litellm/router_utils/cooldown_callbacks.py index 54a016d3ec8d..5961a04feb4b 100644 --- a/litellm/router_utils/cooldown_callbacks.py +++ b/litellm/router_utils/cooldown_callbacks.py @@ -59,9 +59,9 @@ async def router_cooldown_event_callback( pass # get the prometheus logger from in memory loggers - prometheusLogger: Optional[PrometheusLogger] = ( - _get_prometheus_logger_from_callbacks() - ) + prometheusLogger: Optional[ + PrometheusLogger + ] = _get_prometheus_logger_from_callbacks() if prometheusLogger is not None: prometheusLogger.set_deployment_complete_outage( diff --git a/litellm/router_utils/pattern_match_deployments.py b/litellm/router_utils/pattern_match_deployments.py index 729510574a87..c6804b1ad4cb 100644 --- a/litellm/router_utils/pattern_match_deployments.py +++ b/litellm/router_utils/pattern_match_deployments.py @@ -105,13 +105,11 @@ def _return_pattern_matched_deployments( new_deployments = [] for deployment in deployments: new_deployment = copy.deepcopy(deployment) - new_deployment["litellm_params"]["model"] = ( - PatternMatchRouter.set_deployment_model_name( - matched_pattern=matched_pattern, - litellm_deployment_litellm_model=deployment["litellm_params"][ - "model" - ], - ) + new_deployment["litellm_params"][ + "model" + ] = PatternMatchRouter.set_deployment_model_name( + matched_pattern=matched_pattern, + litellm_deployment_litellm_model=deployment["litellm_params"]["model"], ) new_deployments.append(new_deployment) diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py index fd89d6c5584c..327dbf3d1963 100644 --- a/litellm/secret_managers/aws_secret_manager_v2.py +++ b/litellm/secret_managers/aws_secret_manager_v2.py @@ -50,7 +50,6 @@ def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]): if use_aws_secret_manager is None or use_aws_secret_manager is False: return try: - cls.validate_environment() litellm.secret_manager_client = cls() litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER diff --git a/litellm/types/integrations/arize_phoenix.py b/litellm/types/integrations/arize_phoenix.py index 4566022d17f2..a8a1fed5a6b1 100644 --- a/litellm/types/integrations/arize_phoenix.py +++ b/litellm/types/integrations/arize_phoenix.py @@ -1,9 +1,11 @@ from typing import TYPE_CHECKING, Literal, Optional from pydantic import BaseModel + from .arize import Protocol + class ArizePhoenixConfig(BaseModel): otlp_auth_headers: Optional[str] = None protocol: Protocol - endpoint: str + endpoint: str diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 1c5552637c6e..3ba5a3a4e0fb 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -722,12 +722,12 @@ def __init__(self, **kwargs): class Hyperparameters(BaseModel): batch_size: Optional[Union[str, int]] = None # "Number of examples in each 
batch." - learning_rate_multiplier: Optional[Union[str, float]] = ( - None # Scaling factor for the learning rate - ) - n_epochs: Optional[Union[str, int]] = ( - None # "The number of epochs to train the model for" - ) + learning_rate_multiplier: Optional[ + Union[str, float] + ] = None # Scaling factor for the learning rate + n_epochs: Optional[ + Union[str, int] + ] = None # "The number of epochs to train the model for" class FineTuningJobCreate(BaseModel): @@ -754,18 +754,18 @@ class FineTuningJobCreate(BaseModel): model: str # "The name of the model to fine-tune." training_file: str # "The ID of an uploaded file that contains training data." - hyperparameters: Optional[Hyperparameters] = ( - None # "The hyperparameters used for the fine-tuning job." - ) - suffix: Optional[str] = ( - None # "A string of up to 18 characters that will be added to your fine-tuned model name." - ) - validation_file: Optional[str] = ( - None # "The ID of an uploaded file that contains validation data." - ) - integrations: Optional[List[str]] = ( - None # "A list of integrations to enable for your fine-tuning job." - ) + hyperparameters: Optional[ + Hyperparameters + ] = None # "The hyperparameters used for the fine-tuning job." + suffix: Optional[ + str + ] = None # "A string of up to 18 characters that will be added to your fine-tuned model name." + validation_file: Optional[ + str + ] = None # "The ID of an uploaded file that contains validation data." + integrations: Optional[ + List[str] + ] = None # "A list of integrations to enable for your fine-tuning job." seed: Optional[int] = None # "The seed controls the reproducibility of the job." diff --git a/litellm/types/rerank.py b/litellm/types/rerank.py index 8e2a8cc33473..fb6dae0d1df9 100644 --- a/litellm/types/rerank.py +++ b/litellm/types/rerank.py @@ -21,7 +21,6 @@ class RerankRequest(BaseModel): max_tokens_per_doc: Optional[int] = None - class OptionalRerankParams(TypedDict, total=False): query: str top_n: Optional[int] @@ -60,9 +59,9 @@ class RerankResponseResult(TypedDict, total=False): class RerankResponse(BaseModel): id: Optional[str] = None - results: Optional[List[RerankResponseResult]] = ( - None # Contains index and relevance_score - ) + results: Optional[ + List[RerankResponseResult] + ] = None # Contains index and relevance_score meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units # Define private attributes using PrivateAttr diff --git a/litellm/types/router.py b/litellm/types/router.py index dcd547def294..45a8a3fcf6db 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -95,18 +95,16 @@ class ModelInfo(BaseModel): id: Optional[ str ] # Allow id to be optional on input, but it will always be present as a str in the model instance - db_model: bool = ( - False # used for proxy - to separate models which are stored in the db vs. config. - ) + db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config. 
updated_at: Optional[datetime.datetime] = None updated_by: Optional[str] = None created_at: Optional[datetime.datetime] = None created_by: Optional[str] = None - base_model: Optional[str] = ( - None # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking - ) + base_model: Optional[ + str + ] = None # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking tier: Optional[Literal["free", "paid"]] = None """ @@ -171,12 +169,12 @@ class GenericLiteLLMParams(CredentialLiteLLMParams): custom_llm_provider: Optional[str] = None tpm: Optional[int] = None rpm: Optional[int] = None - timeout: Optional[Union[float, str, httpx.Timeout]] = ( - None # if str, pass in as os.environ/ - ) - stream_timeout: Optional[Union[float, str]] = ( - None # timeout when making stream=True calls, if str, pass in as os.environ/ - ) + timeout: Optional[ + Union[float, str, httpx.Timeout] + ] = None # if str, pass in as os.environ/ + stream_timeout: Optional[ + Union[float, str] + ] = None # timeout when making stream=True calls, if str, pass in as os.environ/ max_retries: Optional[int] = None organization: Optional[str] = None # for openai orgs configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None @@ -251,9 +249,9 @@ def __init__( if max_retries is not None and isinstance(max_retries, str): max_retries = int(max_retries) # cast to int # We need to keep max_retries in args since it's a parameter of GenericLiteLLMParams - args["max_retries"] = ( - max_retries # Put max_retries back in args after popping it - ) + args[ + "max_retries" + ] = max_retries # Put max_retries back in args after popping it super().__init__(**args, **params) def __contains__(self, key): @@ -577,7 +575,6 @@ class AssistantsTypedDict(TypedDict): class FineTuningConfig(BaseModel): - custom_llm_provider: Literal["azure", "openai"] diff --git a/litellm/types/utils.py b/litellm/types/utils.py index fe6330f8bdff..7f84a41cd555 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -482,7 +482,6 @@ def __setitem__(self, key, value): class ChatCompletionAudioResponse(ChatCompletionAudio): - def __init__( self, data: str, @@ -927,7 +926,6 @@ def __init__( self.finish_reason = None self.index = index if delta is not None: - if isinstance(delta, Delta): self.delta = delta elif isinstance(delta, dict): @@ -961,7 +959,6 @@ def __setitem__(self, key, value): class StreamingChatCompletionChunk(OpenAIChatCompletionChunk): def __init__(self, **kwargs): - new_choices = [] for choice in kwargs["choices"]: new_choice = StreamingChoices(**choice).model_dump() diff --git a/litellm/utils.py b/litellm/utils.py index 3c8b6667f93c..777352ed3456 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -482,7 +482,6 @@ def get_dynamic_callbacks( def function_setup( # noqa: PLR0915 original_function: str, rules_obj, start_time, *args, **kwargs ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. 
- ### NOTICES ### from litellm import Logging as LiteLLMLogging from litellm.litellm_core_utils.litellm_logging import set_callbacks @@ -504,9 +503,9 @@ def function_setup( # noqa: PLR0915 function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None ## DYNAMIC CALLBACKS ## - dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = ( - kwargs.pop("callbacks", None) - ) + dynamic_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = kwargs.pop("callbacks", None) all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks) if len(all_callbacks) > 0: @@ -1190,9 +1189,9 @@ def wrapper(*args, **kwargs): # noqa: PLR0915 exception=e, retry_policy=kwargs.get("retry_policy"), ) - kwargs["retry_policy"] = ( - reset_retry_policy() - ) # prevent infinite loops + kwargs[ + "retry_policy" + ] = reset_retry_policy() # prevent infinite loops litellm.num_retries = ( None # set retries to None to prevent infinite loops ) @@ -1404,7 +1403,6 @@ async def wrapper_async(*args, **kwargs): # noqa: PLR0915 if ( num_retries and not _is_litellm_router_call ): # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying - try: litellm.num_retries = ( None # set retries to None to prevent infinite loops @@ -1425,7 +1423,6 @@ async def wrapper_async(*args, **kwargs): # noqa: PLR0915 and context_window_fallback_dict and model in context_window_fallback_dict ): - if len(args) > 0: args[0] = context_window_fallback_dict[model] # type: ignore else: @@ -1521,7 +1518,6 @@ def _select_tokenizer( @lru_cache(maxsize=128) def _select_tokenizer_helper(model: str) -> SelectTokenizerResponse: - if litellm.disable_hf_tokenizer_download is True: return _return_openai_tokenizer(model) @@ -2990,16 +2986,16 @@ def get_optional_params( # noqa: PLR0915 True # so that main.py adds the function call to the prompt ) if "tools" in non_default_params: - optional_params["functions_unsupported_model"] = ( - non_default_params.pop("tools") - ) + optional_params[ + "functions_unsupported_model" + ] = non_default_params.pop("tools") non_default_params.pop( "tool_choice", None ) # causes ollama requests to hang elif "functions" in non_default_params: - optional_params["functions_unsupported_model"] = ( - non_default_params.pop("functions") - ) + optional_params[ + "functions_unsupported_model" + ] = non_default_params.pop("functions") elif ( litellm.add_function_to_prompt ): # if user opts to add it to prompt instead @@ -3022,10 +3018,10 @@ def get_optional_params( # noqa: PLR0915 if "response_format" in non_default_params: if provider_config is not None: - non_default_params["response_format"] = ( - provider_config.get_json_schema_from_pydantic_object( - response_format=non_default_params["response_format"] - ) + non_default_params[ + "response_format" + ] = provider_config.get_json_schema_from_pydantic_object( + response_format=non_default_params["response_format"] ) else: non_default_params["response_format"] = type_to_response_format_param( @@ -3177,7 +3173,6 @@ def _check_valid_arg(supported_params: List[str]): ), ) elif custom_llm_provider == "replicate": - optional_params = litellm.ReplicateConfig().map_openai_params( non_default_params=non_default_params, optional_params=optional_params, @@ -3211,7 +3206,6 @@ def _check_valid_arg(supported_params: List[str]): ), ) elif custom_llm_provider == "together_ai": - optional_params = litellm.TogetherAIConfig().map_openai_params( non_default_params=non_default_params, optional_params=optional_params, @@ -3279,7 
+3273,6 @@ def _check_valid_arg(supported_params: List[str]): ), ) elif custom_llm_provider == "vertex_ai": - if model in litellm.vertex_mistral_models: if "codestral" in model: optional_params = ( @@ -3358,7 +3351,6 @@ def _check_valid_arg(supported_params: List[str]): elif "anthropic" in bedrock_base_model and bedrock_route == "invoke": if bedrock_base_model.startswith("anthropic.claude-3"): - optional_params = ( litellm.AmazonAnthropicClaude3Config().map_openai_params( non_default_params=non_default_params, @@ -3395,7 +3387,6 @@ def _check_valid_arg(supported_params: List[str]): ), ) elif custom_llm_provider == "cloudflare": - optional_params = litellm.CloudflareChatConfig().map_openai_params( model=model, non_default_params=non_default_params, @@ -3407,7 +3398,6 @@ def _check_valid_arg(supported_params: List[str]): ), ) elif custom_llm_provider == "ollama": - optional_params = litellm.OllamaConfig().map_openai_params( non_default_params=non_default_params, optional_params=optional_params, @@ -3419,7 +3409,6 @@ def _check_valid_arg(supported_params: List[str]): ), ) elif custom_llm_provider == "ollama_chat": - optional_params = litellm.OllamaChatConfig().map_openai_params( model=model, non_default_params=non_default_params, @@ -4005,9 +3994,9 @@ def _count_characters(text: str) -> int: def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str: - _choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = ( - response_obj.choices - ) + _choices: Union[ + List[Union[Choices, StreamingChoices]], List[StreamingChoices] + ] = response_obj.choices response_str = "" for choice in _choices: @@ -4405,7 +4394,6 @@ def _get_model_info_helper( # noqa: PLR0915 ): _model_info = None if _model_info is None and model in litellm.model_cost: - key = model _model_info = _get_model_info_from_model_cost(key=key) if not _check_provider_match( @@ -4416,7 +4404,6 @@ def _get_model_info_helper( # noqa: PLR0915 _model_info is None and combined_stripped_model_name in litellm.model_cost ): - key = combined_stripped_model_name _model_info = _get_model_info_from_model_cost(key=key) if not _check_provider_match( @@ -4424,7 +4411,6 @@ def _get_model_info_helper( # noqa: PLR0915 ): _model_info = None if _model_info is None and stripped_model_name in litellm.model_cost: - key = stripped_model_name _model_info = _get_model_info_from_model_cost(key=key) if not _check_provider_match( @@ -4432,7 +4418,6 @@ def _get_model_info_helper( # noqa: PLR0915 ): _model_info = None if _model_info is None and split_model in litellm.model_cost: - key = split_model _model_info = _get_model_info_from_model_cost(key=key) if not _check_provider_match( From 6397ffbb7265123fae956858eddc332f1291bc31 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 29 Mar 2025 09:18:12 -0700 Subject: [PATCH 03/14] ci: reformat to fit black --- litellm/caching/redis_semantic_cache.py | 13 +++-- .../amazon_nova_canvas_transformation.py | 52 ++++++++++++------- 2 files changed, 42 insertions(+), 23 deletions(-) diff --git a/litellm/caching/redis_semantic_cache.py b/litellm/caching/redis_semantic_cache.py index 6ebd1c8f33a9..c76f27377d8e 100644 --- a/litellm/caching/redis_semantic_cache.py +++ b/litellm/caching/redis_semantic_cache.py @@ -143,11 +143,14 @@ def _get_embedding(self, prompt: str) -> List[float]: List[float]: The embedding vector """ # Create an embedding from prompt - embedding_response = cast(EmbeddingResponse, litellm.embedding( - model=self.embedding_model, - input=prompt, - 
cache={"no-store": True, "no-cache": True}, - )) + embedding_response = cast( + EmbeddingResponse, + litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ), + ) embedding = embedding_response["data"][0]["embedding"] return embedding diff --git a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py index 68ddf4ca880e..8575b3baed59 100644 --- a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py +++ b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py @@ -68,19 +68,27 @@ def transform_request_body( image_generation_config = optional_params.pop("imageGenerationConfig", {}) image_generation_config = {**image_generation_config, **optional_params} if task_type == "TEXT_IMAGE": - text_to_image_params: Dict[str, Any] = image_generation_config.pop("textToImageParams", {}) + text_to_image_params: Dict[str, Any] = image_generation_config.pop( + "textToImageParams", {} + ) text_to_image_params = {"text": text, **text_to_image_params} - try: + try: text_to_image_params_typed = AmazonNovaCanvasTextToImageParams( - **text_to_image_params + **text_to_image_params ) except Exception as e: - raise ValueError(f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}") - - try: - image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(**image_generation_config) + raise ValueError( + f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}" + ) + + try: + image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig( + **image_generation_config + ) except Exception as e: - raise ValueError(f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}") + raise ValueError( + f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}" + ) return AmazonNovaCanvasTextToImageRequest( textToImageParams=text_to_image_params_typed, @@ -88,24 +96,32 @@ def transform_request_body( imageGenerationConfig=image_generation_config_typed, ) if task_type == "COLOR_GUIDED_GENERATION": - color_guided_generation_params: Dict[str, Any] = image_generation_config.pop( - "colorGuidedGenerationParams", {} - ) + color_guided_generation_params: Dict[ + str, Any + ] = image_generation_config.pop("colorGuidedGenerationParams", {}) color_guided_generation_params = { "text": text, **color_guided_generation_params, } - try: - color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams( - **color_guided_generation_params + try: + color_guided_generation_params_typed = ( + AmazonNovaCanvasColorGuidedGenerationParams( + **color_guided_generation_params + ) ) except Exception as e: - raise ValueError(f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}") + raise ValueError( + f"Error transforming color guided generation params: {e}. 
Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}" + ) - try: - image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(**image_generation_config) + try: + image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig( + **image_generation_config + ) except Exception as e: - raise ValueError(f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}") + raise ValueError( + f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}" + ) return AmazonNovaCanvasColorGuidedRequest( taskType=task_type, From c285a02bee34cbb1ebc6e33aaecd83f85c9eb380 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 29 Mar 2025 09:19:23 -0700 Subject: [PATCH 04/14] ci(test-litellm.yml): make tests run clear --- .github/workflows/test-litellm.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-litellm.yml b/.github/workflows/test-litellm.yml index 09de0ff1105c..12d09725ed1a 100644 --- a/.github/workflows/test-litellm.yml +++ b/.github/workflows/test-litellm.yml @@ -1,4 +1,4 @@ -name: LiteLLM Tests +name: LiteLLM Mock Tests (folder - tests/litellm) on: pull_request: From 5375fe94ec2c8db5aecadc980f658e6207cc66fa Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 29 Mar 2025 09:22:25 -0700 Subject: [PATCH 05/14] build(pyproject.toml): add ruff --- poetry.lock | 28 +++++++++++++++++++++++++++- pyproject.toml | 1 + 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 9c42278c14bf..e36808a56ec5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3326,6 +3326,32 @@ files = [ [package.dependencies] pyasn1 = ">=0.1.3" +[[package]] +name = "ruff" +version = "0.1.15" +description = "An extremely fast Python linter and code formatter, written in Rust." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.1.15-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5fe8d54df166ecc24106db7dd6a68d44852d14eb0729ea4672bb4d96c320b7df"}, + {file = "ruff-0.1.15-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6f0bfbb53c4b4de117ac4d6ddfd33aa5fc31beeaa21d23c45c6dd249faf9126f"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0d432aec35bfc0d800d4f70eba26e23a352386be3a6cf157083d18f6f5881c8"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9405fa9ac0e97f35aaddf185a1be194a589424b8713e3b97b762336ec79ff807"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c66ec24fe36841636e814b8f90f572a8c0cb0e54d8b5c2d0e300d28a0d7bffec"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6f8ad828f01e8dd32cc58bc28375150171d198491fc901f6f98d2a39ba8e3ff5"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86811954eec63e9ea162af0ffa9f8d09088bab51b7438e8b6488b9401863c25e"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd4025ac5e87d9b80e1f300207eb2fd099ff8200fa2320d7dc066a3f4622dc6b"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b17b93c02cdb6aeb696effecea1095ac93f3884a49a554a9afa76bb125c114c1"}, + {file = "ruff-0.1.15-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ddb87643be40f034e97e97f5bc2ef7ce39de20e34608f3f829db727a93fb82c5"}, + {file = "ruff-0.1.15-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:abf4822129ed3a5ce54383d5f0e964e7fef74a41e48eb1dfad404151efc130a2"}, + {file = "ruff-0.1.15-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6c629cf64bacfd136c07c78ac10a54578ec9d1bd2a9d395efbee0935868bf852"}, + {file = "ruff-0.1.15-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1bab866aafb53da39c2cadfb8e1c4550ac5340bb40300083eb8967ba25481447"}, + {file = "ruff-0.1.15-py3-none-win32.whl", hash = "sha256:2417e1cb6e2068389b07e6fa74c306b2810fe3ee3476d5b8a96616633f40d14f"}, + {file = "ruff-0.1.15-py3-none-win_amd64.whl", hash = "sha256:3837ac73d869efc4182d9036b1405ef4c73d9b1f88da2413875e34e0d6919587"}, + {file = "ruff-0.1.15-py3-none-win_arm64.whl", hash = "sha256:9a933dfb1c14ec7a33cceb1e49ec4a16b51ce3c20fd42663198746efc0427360"}, + {file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"}, +] + [[package]] name = "s3transfer" version = "0.10.4" @@ -4098,4 +4124,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi", [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0, !=3.9.7" -content-hash = "efd348b2920530f18696f75f684026be542366eeec2f0b13f53ed9277a71c56a" +content-hash = "0beef7d83165f7e0e1608f7877ef6c268ccd37ab6bf796c5b3b57d220c3cf538" diff --git a/pyproject.toml b/pyproject.toml index fd17cdca178d..44b6f40c718e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ pytest = "^7.4.3" pytest-mock = "^3.12.0" pytest-asyncio = "^0.21.1" respx = "^0.20.2" +ruff = "^0.1.0" types-requests = "*" types-setuptools = "*" types-redis = "*" From 185a20213b173214255917317f4ae3bc6b330566 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 29 Mar 2025 09:25:22 -0700 Subject: [PATCH 06/14] fix: fix ruff checks --- litellm/caching/s3_cache.py | 2 +- litellm/integrations/s3.py | 2 +- 
litellm/proxy/proxy_server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/litellm/caching/s3_cache.py b/litellm/caching/s3_cache.py index 301591c64fec..c02e10913690 100644 --- a/litellm/caching/s3_cache.py +++ b/litellm/caching/s3_cache.py @@ -123,7 +123,7 @@ def get_cache(self, key, **kwargs): ) # Convert string to dictionary except Exception: cached_response = ast.literal_eval(cached_response) - if type(cached_response) is not dict: + if not isinstance(cached_response, dict): cached_response = dict(cached_response) verbose_logger.debug( f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}" diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py index 4a0c27354f4f..01b9248e0315 100644 --- a/litellm/integrations/s3.py +++ b/litellm/integrations/s3.py @@ -38,7 +38,7 @@ def __init__( if litellm.s3_callback_params is not None: # read in .env variables - example os.environ/AWS_BUCKET_NAME for key, value in litellm.s3_callback_params.items(): - if type(value) is str and value.startswith("os.environ/"): + if isinstance(value, str) and value.startswith("os.environ/"): litellm.s3_callback_params[key] = litellm.get_secret(value) # now set s3 params from litellm.s3_logger_params s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name") diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 026e9a42cece..756074521e37 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1588,7 +1588,7 @@ async def load_config( # noqa: PLR0915 # users can pass os.environ/ variables on the proxy - we should read them from the env for key, value in cache_params.items(): - if type(value) is str and value.startswith("os.environ/"): + if isinstance(value, str) and value.startswith("os.environ/"): cache_params[key] = get_secret(value) ## to pass a complete url, or set ssl=True, etc. 
just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables From 22c6c708ae58a448b80eaf75b4a185559fa3b5f8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 29 Mar 2025 09:30:21 -0700 Subject: [PATCH 07/14] build(mypy/): fix mypy linting errors --- litellm/_service_logger.py | 2 +- litellm/caching/base_cache.py | 4 ++-- litellm/caching/disk_cache.py | 4 ++-- litellm/caching/dual_cache.py | 4 ++-- litellm/caching/redis_cache.py | 2 +- litellm/caching/redis_cluster_cache.py | 4 ++-- litellm/integrations/arize/_utils.py | 4 ++-- litellm/integrations/arize/arize.py | 2 +- litellm/integrations/arize/arize_phoenix.py | 4 ++-- litellm/integrations/custom_logger.py | 2 +- litellm/integrations/langtrace.py | 4 ++-- litellm/integrations/opentelemetry.py | 2 +- litellm/litellm_core_utils/core_helpers.py | 2 +- litellm/proxy/_types.py | 2 +- litellm/proxy/auth/auth_checks.py | 4 ++-- litellm/proxy/auth/auth_exception_handler.py | 4 ++-- litellm/proxy/hooks/parallel_request_limiter.py | 2 +- litellm/proxy/proxy_server.py | 2 +- litellm/proxy/utils.py | 2 +- litellm/router.py | 2 +- litellm/router_strategy/lowest_latency.py | 2 +- litellm/router_strategy/lowest_tpm_rpm_v2.py | 2 +- litellm/router_utils/cooldown_cache.py | 4 ++-- litellm/router_utils/cooldown_handlers.py | 2 +- litellm/router_utils/handle_error.py | 4 ++-- litellm/router_utils/prompt_caching_cache.py | 4 ++-- 26 files changed, 38 insertions(+), 38 deletions(-) diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index 0b4f22e210bc..8f835bea8305 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] OTELClass = OpenTelemetry else: Span = Any diff --git a/litellm/caching/base_cache.py b/litellm/caching/base_cache.py index 7109951d1599..5140b390f761 100644 --- a/litellm/caching/base_cache.py +++ b/litellm/caching/base_cache.py @@ -9,12 +9,12 @@ """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/caching/disk_cache.py b/litellm/caching/disk_cache.py index abf3203f507a..413ac2932d3d 100644 --- a/litellm/caching/disk_cache.py +++ b/litellm/caching/disk_cache.py @@ -1,12 +1,12 @@ import json -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from .base_cache import BaseCache if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/caching/dual_cache.py b/litellm/caching/dual_cache.py index 5f598f7d7036..8bef3337587a 100644 --- a/litellm/caching/dual_cache.py +++ b/litellm/caching/dual_cache.py @@ -12,7 +12,7 @@ import time import traceback from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Union import litellm from litellm._logging import print_verbose, verbose_logger @@ -24,7 +24,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index 29fa44715374..63cd4d095987 100644 --- a/litellm/caching/redis_cache.py +++ 
b/litellm/caching/redis_cache.py @@ -34,7 +34,7 @@ cluster_pipeline = ClusterPipeline async_redis_client = Redis async_redis_cluster_client = RedisCluster - Span = _Span + Span = Union[_Span, Any] else: pipeline = Any cluster_pipeline = Any diff --git a/litellm/caching/redis_cluster_cache.py b/litellm/caching/redis_cluster_cache.py index 2e7d1de17f30..21c3ab0366b9 100644 --- a/litellm/caching/redis_cluster_cache.py +++ b/litellm/caching/redis_cluster_cache.py @@ -5,7 +5,7 @@ - RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created """ -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Union from litellm.caching.redis_cache import RedisCache @@ -16,7 +16,7 @@ pipeline = Pipeline async_redis_client = Redis - Span = _Span + Span = Union[_Span, Any] else: pipeline = Any async_redis_client = Any diff --git a/litellm/integrations/arize/_utils.py b/litellm/integrations/arize/_utils.py index 487304cce4f5..5a090968b4fa 100644 --- a/litellm/integrations/arize/_utils.py +++ b/litellm/integrations/arize/_utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from litellm._logging import verbose_logger from litellm.litellm_core_utils.safe_json_dumps import safe_dumps @@ -7,7 +7,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/integrations/arize/arize.py b/litellm/integrations/arize/arize.py index 7aa412eef52e..03b6966809ce 100644 --- a/litellm/integrations/arize/arize.py +++ b/litellm/integrations/arize/arize.py @@ -19,7 +19,7 @@ from litellm.types.integrations.arize import Protocol as _Protocol Protocol = _Protocol - Span = _Span + Span = Union[_Span, Any] else: Protocol = Any Span = Any diff --git a/litellm/integrations/arize/arize_phoenix.py b/litellm/integrations/arize/arize_phoenix.py index b2f77522241e..2b4909885a36 100644 --- a/litellm/integrations/arize/arize_phoenix.py +++ b/litellm/integrations/arize/arize_phoenix.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union from litellm._logging import verbose_logger from litellm.integrations.arize import _utils @@ -14,7 +14,7 @@ Protocol = _Protocol OpenTelemetryConfig = _OpenTelemetryConfig - Span = _Span + Span = Union[_Span, Any] else: Protocol = Any OpenTelemetryConfig = Any diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 6f1ec88d0104..ddb8094285ac 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/integrations/langtrace.py b/litellm/integrations/langtrace.py index 51cd272ff1c1..ac1069f440e6 100644 --- a/litellm/integrations/langtrace.py +++ b/litellm/integrations/langtrace.py @@ -1,12 +1,12 @@ import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union from litellm.proxy._types import SpanAttributes if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 177b2ae02bd7..47571ec9f874 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py 
@@ -23,7 +23,7 @@ ) from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth - Span = _Span + Span = Union[_Span, Any] SpanExporter = _SpanExporter UserAPIKeyAuth = _UserAPIKeyAuth ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 2036b93692e0..275c53ad3083 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 394f49df7a23..16b45f38377e 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -38,7 +38,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 4fd718351962..ddd1008bd039 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -11,7 +11,7 @@ import asyncio import re import time -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast from fastapi import Request, status from pydantic import BaseModel @@ -49,7 +49,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/proxy/auth/auth_exception_handler.py b/litellm/proxy/auth/auth_exception_handler.py index 5dd30a075849..7c9765514148 100644 --- a/litellm/proxy/auth/auth_exception_handler.py +++ b/litellm/proxy/auth/auth_exception_handler.py @@ -3,7 +3,7 @@ """ import asyncio -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from fastapi import HTTPException, Request, status @@ -17,7 +17,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 83b3c7179aca..242c013d6773 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -22,7 +22,7 @@ from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache - Span = _Span + Span = Union[_Span, Any] InternalUsageCache = _InternalUsageCache else: Span = Any diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 756074521e37..b9a29b81f48f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -36,7 +36,7 @@ from litellm.integrations.opentelemetry import OpenTelemetry - Span = _Span + Span = Union[_Span, Any] else: Span = Any OpenTelemetry = Any diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 7c73c45fb1b8..900b26f3f1a0 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -78,7 +78,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/router.py b/litellm/router.py index d0755f9a0d36..78ad2afe1a9a 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -148,7 +148,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git 
diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py
index b049c942642f..55ca98843d7c 100644
--- a/litellm/router_strategy/lowest_latency.py
+++ b/litellm/router_strategy/lowest_latency.py
@@ -14,7 +14,7 @@
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
 
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any
 
diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py
index 9b1836586009..9e6c139314f9 100644
--- a/litellm/router_strategy/lowest_tpm_rpm_v2.py
+++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py
@@ -20,7 +20,7 @@
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
 
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any
 
diff --git a/litellm/router_utils/cooldown_cache.py b/litellm/router_utils/cooldown_cache.py
index f096b026c0a2..13d6318fc4f7 100644
--- a/litellm/router_utils/cooldown_cache.py
+++ b/litellm/router_utils/cooldown_cache.py
@@ -3,7 +3,7 @@
 """
 
 import time
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict, Union
 
 from litellm import verbose_logger
 from litellm.caching.caching import DualCache
@@ -12,7 +12,7 @@
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
 
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any
 
diff --git a/litellm/router_utils/cooldown_handlers.py b/litellm/router_utils/cooldown_handlers.py
index 52babc27f2c7..ed9c2dd22978 100644
--- a/litellm/router_utils/cooldown_handlers.py
+++ b/litellm/router_utils/cooldown_handlers.py
@@ -29,7 +29,7 @@
     from litellm.router import Router as _Router
 
     LitellmRouter = _Router
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     LitellmRouter = Any
     Span = Any
diff --git a/litellm/router_utils/handle_error.py b/litellm/router_utils/handle_error.py
index 132440cbc3c5..c331da70acf2 100644
--- a/litellm/router_utils/handle_error.py
+++ b/litellm/router_utils/handle_error.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 from litellm._logging import verbose_router_logger
 from litellm.router_utils.cooldown_handlers import (
@@ -13,7 +13,7 @@
     from litellm.router import Router as _Router
 
     LitellmRouter = _Router
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     LitellmRouter = Any
     Span = Any
diff --git a/litellm/router_utils/prompt_caching_cache.py b/litellm/router_utils/prompt_caching_cache.py
index 1bf686d694a2..6a96b85e8aff 100644
--- a/litellm/router_utils/prompt_caching_cache.py
+++ b/litellm/router_utils/prompt_caching_cache.py
@@ -4,7 +4,7 @@
 import hashlib
 import json
-from typing import TYPE_CHECKING, Any, List, Optional, TypedDict
+from typing import TYPE_CHECKING, Any, List, Optional, TypedDict, Union
 
 from litellm.caching.caching import DualCache
 from litellm.caching.in_memory_cache import InMemoryCache
 
@@ -16,7 +16,7 @@
     from litellm.router import Router
 
     litellm_router = Router
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any
     litellm_router = Any

From 0151f09d723060fd3ca71bbf7c0063ab34fbe917 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 29 Mar 2025 09:32:54 -0700
Subject: [PATCH 08/14] fix(hashicorp_secret_manager.py): fix passing cert for
 tls auth

---
 litellm/secret_managers/hashicorp_secret_manager.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/litellm/secret_managers/hashicorp_secret_manager.py b/litellm/secret_managers/hashicorp_secret_manager.py
index a3d129f89ce6..859b28588181 100644
--- a/litellm/secret_managers/hashicorp_secret_manager.py
+++ b/litellm/secret_managers/hashicorp_secret_manager.py
@@ -90,9 +90,9 @@ def _auth_via_tls_cert(self) -> str:
             headers["X-Vault-Namespace"] = self.vault_namespace
         try:
             # We use the client cert and key for mutual TLS
-            resp = httpx.post(
+            client = httpx.Client(cert=(self.tls_cert_path, self.tls_key_path))
+            resp = client.post(
                 login_url,
-                cert=(self.tls_cert_path, self.tls_key_path),
                 headers=headers,
                 json=self._get_tls_cert_auth_body(),
            )
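PATCH 08 moves the client certificate out of the individual request call: httpx ties TLS client certificates to the Client (its cert parameter), and the login request is then issued from that client. A rough sketch of the resulting mutual-TLS Vault login is below; the Vault address, namespace, role name, and certificate paths are illustrative placeholders rather than values from this repository.

# Sketch of the mutual-TLS cert login flow after PATCH 08 (placeholder values only).
import httpx

VAULT_ADDR = "https://vault.example.com:8200"          # placeholder
CERT_PATH, KEY_PATH = "client.pem", "client-key.pem"   # placeholders

# The cert/key pair is bound to the Client itself rather than passed per request.
client = httpx.Client(cert=(CERT_PATH, KEY_PATH))
resp = client.post(
    f"{VAULT_ADDR}/v1/auth/cert/login",
    headers={"X-Vault-Namespace": "example-namespace"},
    json={"name": "example-cert-role"},
)
resp.raise_for_status()
vault_token = resp.json()["auth"]["client_token"]
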
From fca4386142845b6346b4151c5bc1e51af8394552 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 29 Mar 2025 09:57:26 -0700
Subject: [PATCH 09/14] build(mypy/): resolve all mypy errors

---
 litellm/integrations/opentelemetry.py          | 12 ++++++------
 .../image/amazon_nova_canvas_transformation.py |  8 ++++----
 poetry.lock                                    |  2 +-
 pyproject.toml                                 |  2 +-
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py
index 47571ec9f874..f0a083dcb627 100644
--- a/litellm/integrations/opentelemetry.py
+++ b/litellm/integrations/opentelemetry.py
@@ -1,7 +1,7 @@
 import os
 from dataclasses import dataclass
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
 
 import litellm
 from litellm._logging import verbose_logger
@@ -24,9 +24,9 @@
     from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
 
     Span = Union[_Span, Any]
-    SpanExporter = _SpanExporter
-    UserAPIKeyAuth = _UserAPIKeyAuth
-    ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
+    SpanExporter = Union[_SpanExporter, Any]
+    UserAPIKeyAuth = Union[_UserAPIKeyAuth, Any]
+    ManagementEndpointLoggingPayload = Union[_ManagementEndpointLoggingPayload, Any]
 else:
     Span = Any
     SpanExporter = Any
@@ -839,12 +839,12 @@ def _get_span_processor(self, dynamic_headers: Optional[dict] = None):
                 headers=dynamic_headers or self.OTEL_HEADERS
             )
 
-        if isinstance(self.OTEL_EXPORTER, SpanExporter):
+        if hasattr(self.OTEL_EXPORTER, "export"):  # Check if it has the export method that SpanExporter requires
             verbose_logger.debug(
                 "OpenTelemetry: intiializing SpanExporter. Value of OTEL_EXPORTER: %s",
                 self.OTEL_EXPORTER,
             )
-            return SimpleSpanProcessor(self.OTEL_EXPORTER)
+            return SimpleSpanProcessor(cast(SpanExporter, self.OTEL_EXPORTER))
 
         if self.OTEL_EXPORTER == "console":
             verbose_logger.debug(
diff --git a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py
index 8575b3baed59..1b6110b8ad9a 100644
--- a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py
+++ b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py
@@ -74,8 +74,8 @@ def transform_request_body(
         text_to_image_params = {"text": text, **text_to_image_params}
         try:
             text_to_image_params_typed = AmazonNovaCanvasTextToImageParams(
-                **text_to_image_params
-            )
+                **text_to_image_params # type: ignore
+                )
         except Exception as e:
             raise ValueError(
                 f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}"
@@ -105,8 +105,8 @@ def transform_request_body(
         }
         try:
             color_guided_generation_params_typed = (
-                AmazonNovaCanvasColorGuidedGenerationParams(
-                    **color_guided_generation_params
+                AmazonNovaCanvasColorGuidedGenerationParams(
+                **color_guided_generation_params # type: ignore
                 )
             )
         except Exception as e:
diff --git a/poetry.lock b/poetry.lock
index e36808a56ec5..d659d669522f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -4124,4 +4124,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi",
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0, !=3.9.7"
-content-hash = "0beef7d83165f7e0e1608f7877ef6c268ccd37ab6bf796c5b3b57d220c3cf538"
+content-hash = "36792478ff4afec5c8e748caf9b2ae6bebf3dd223e78bea2626b6589ef3277e4"
diff --git a/pyproject.toml b/pyproject.toml
index 44b6f40c718e..5eb4a7116020 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ Documentation = "https://docs.litellm.ai"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0, !=3.9.7"
 httpx = ">=0.23.0"
-openai = ">=1.66.1"
+openai = ">=1.68.2"
 python-dotenv = ">=0.2.0"
 tiktoken = ">=0.7.0"
 importlib-metadata = ">=6.8.0"
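The one non-mechanical change in PATCH 09 is in _get_span_processor: SpanExporter is only a concrete class while type checking (its runtime fallback is Any), so the isinstance check is replaced with duck typing on the export method, and typing.cast satisfies mypy at the call site. A minimal sketch of that idea, using generic names rather than LiteLLM's:

# Sketch of the duck-typing + cast pattern from PATCH 09 (illustrative names).
from typing import TYPE_CHECKING, Any, Union, cast

if TYPE_CHECKING:
    from opentelemetry.sdk.trace.export import SpanExporter as _SpanExporter

    SpanExporter = Union[_SpanExporter, Any]
else:
    SpanExporter = Any


def as_span_exporter(candidate: object) -> Any:
    # The runtime value of SpanExporter here is typing.Any, which isinstance()
    # rejects, so check for the method the exporter contract requires instead.
    if hasattr(candidate, "export"):
        return cast(SpanExporter, candidate)
    return None
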
From 0184bed203dd23b075a41a1417fbe52b2970d8e1 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 29 Mar 2025 10:04:29 -0700
Subject: [PATCH 10/14] test: update test

---
 .../hashicorp_secret_manager.py             |  2 +-
 tests/litellm_utils_tests/test_hashicorp.py | 36 ++++++++++++-------
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/litellm/secret_managers/hashicorp_secret_manager.py b/litellm/secret_managers/hashicorp_secret_manager.py
index 859b28588181..e0b4a08ce836 100644
--- a/litellm/secret_managers/hashicorp_secret_manager.py
+++ b/litellm/secret_managers/hashicorp_secret_manager.py
@@ -34,7 +34,7 @@ def __init__(self):
         # Validate environment
         if not self.vault_token:
             raise ValueError(
-                "Missing Vault token. Please set VAULT_TOKEN in your environment."
+                "Missing Vault token. Please set HCP_VAULT_TOKEN in your environment."
             )
 
         litellm.secret_manager_client = self
diff --git a/tests/litellm_utils_tests/test_hashicorp.py b/tests/litellm_utils_tests/test_hashicorp.py
index 612af5a79cd2..c61168c99dd1 100644
--- a/tests/litellm_utils_tests/test_hashicorp.py
+++ b/tests/litellm_utils_tests/test_hashicorp.py
@@ -165,19 +165,26 @@ async def test_hashicorp_secret_manager_delete_secret():
     )
 
 
-def test_hashicorp_secret_manager_tls_cert_auth():
-    with patch("httpx.post") as mock_post:
-        # Configure the mock response for TLS auth
-        mock_auth_response = MagicMock()
-        mock_auth_response.json.return_value = {
+def test_hashicorp_secret_manager_tls_cert_auth(monkeypatch):
+    monkeypatch.setenv("HCP_VAULT_TOKEN", "test-client-token-12345")
+    print("HCP_VAULT_TOKEN=", os.getenv("HCP_VAULT_TOKEN"))
+    # Mock both httpx.post and httpx.Client
+    with patch("httpx.Client") as mock_client:
+        # Configure the mock client and response
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
             "auth": {
                 "client_token": "test-client-token-12345",
                 "lease_duration": 3600,
                 "renewable": True,
             }
         }
-        mock_auth_response.raise_for_status.return_value = None
-        mock_post.return_value = mock_auth_response
+        mock_response.raise_for_status.return_value = None
+
+        # Configure the mock client's post method
+        mock_client_instance = MagicMock()
+        mock_client_instance.post.return_value = mock_response
+        mock_client.return_value = mock_client_instance
 
         # Create a new instance with TLS cert config
         test_manager = HashicorpSecretManager()
@@ -185,19 +192,22 @@ def test_hashicorp_secret_manager_tls_cert_auth():
         test_manager.tls_cert_path = "cert.pem"
         test_manager.tls_key_path = "key.pem"
         test_manager.vault_cert_role = "test-role"
         test_manager.vault_namespace = "test-namespace"
+        # Test the TLS auth method
         token = test_manager._auth_via_tls_cert()
 
-        # Verify the token and request parameters
+        # Verify the token
         assert token == "test-client-token-12345"
-        mock_post.assert_called_once_with(
+
+        # Verify Client was created with correct cert tuple
+        mock_client.assert_called_once_with(cert=("cert.pem", "key.pem"))
+
+        # Verify post was called with correct parameters
+        mock_client_instance.post.assert_called_once_with(
             f"{test_manager.vault_addr}/v1/auth/cert/login",
-            cert=("cert.pem", "key.pem"),
             headers={"X-Vault-Namespace": "test-namespace"},
             json={"name": "test-role"},
         )
 
         # Verify the token was cached
-        assert (
-            test_manager.cache.get_cache("hcp_vault_token") == "test-client-token-12345"
-        )
+        assert test_manager.cache.get_cache("hcp_vault_token") == "test-client-token-12345"
From 7e76de59179338269348d549606a06149c14fcc5 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 29 Mar 2025 10:06:19 -0700
Subject: [PATCH 11/14] fix: fix black formatting

---
 litellm/integrations/opentelemetry.py                  |  4 +++-
 .../bedrock/image/amazon_nova_canvas_transformation.py | 10 ++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py
index f0a083dcb627..f4fe40738bab 100644
--- a/litellm/integrations/opentelemetry.py
+++ b/litellm/integrations/opentelemetry.py
@@ -839,7 +839,9 @@ def _get_span_processor(self, dynamic_headers: Optional[dict] = None):
                 headers=dynamic_headers or self.OTEL_HEADERS
             )
 
-        if hasattr(self.OTEL_EXPORTER, "export"):  # Check if it has the export method that SpanExporter requires
+        if hasattr(
+            self.OTEL_EXPORTER, "export"
+        ):  # Check if it has the export method that SpanExporter requires
             verbose_logger.debug(
                 "OpenTelemetry: intiializing SpanExporter. Value of OTEL_EXPORTER: %s",
                 self.OTEL_EXPORTER,
diff --git a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py
index 1b6110b8ad9a..b331dd1b1dc2 100644
--- a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py
+++ b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py
@@ -74,8 +74,8 @@ def transform_request_body(
         text_to_image_params = {"text": text, **text_to_image_params}
         try:
             text_to_image_params_typed = AmazonNovaCanvasTextToImageParams(
-                **text_to_image_params # type: ignore
-                )
+                **text_to_image_params  # type: ignore
+            )
         except Exception as e:
             raise ValueError(
                 f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}"
@@ -104,10 +104,8 @@ def transform_request_body(
             **color_guided_generation_params,
         }
         try:
-            color_guided_generation_params_typed = (
-                AmazonNovaCanvasColorGuidedGenerationParams(
-                **color_guided_generation_params # type: ignore
-                )
+            color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams(
+                **color_guided_generation_params  # type: ignore
             )
         except Exception as e:
             raise ValueError(

From 4863abe268a8da509f28a1ec530ac1c89e510049 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 29 Mar 2025 10:07:33 -0700
Subject: [PATCH 12/14] build(pre-commit-config.yaml): use poetry run black

---
 .pre-commit-config.yaml | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4818ca6ca0fc..bceedb41aa1e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,10 +14,12 @@ repos:
         types: [python]
         files: litellm/.*\.py
         exclude: ^litellm/__init__.py$
-# - repo: https://github.com/psf/black
-#   rev: 24.2.0
-#   hooks:
-#     - id: black
+      - id: black
+        name: black
+        entry: poetry run black
+        language: system
+        types: [python]
+        files: litellm/.*\.py
 - repo: https://github.com/pycqa/flake8
   rev: 7.0.0  # The version of flake8 to use
   hooks:

From 91bc7c4d72862ed21c7ef678b4cd2ea09981035e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 29 Mar 2025 10:08:31 -0700
Subject: [PATCH 13/14] fix(proxy_server.py): fix linting error

---
 litellm/proxy/proxy_server.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index b9a29b81f48f..f59d11718184 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -18,6 +18,7 @@
     List,
     Optional,
     Tuple,
+    Union,
     cast,
     get_args,
     get_origin,

From d44a3fb851e698d3da4ee07c0fd54a37fc59d1ba Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 29 Mar 2025 10:47:56 -0700
Subject: [PATCH 14/14] fix: fix ruff safe representation error

---
 litellm/llms/sagemaker/completion/handler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/litellm/llms/sagemaker/completion/handler.py b/litellm/llms/sagemaker/completion/handler.py
index fcae9d6c03d3..296689c31caa 100644
--- a/litellm/llms/sagemaker/completion/handler.py
+++ b/litellm/llms/sagemaker/completion/handler.py
@@ -627,7 +627,7 @@ def embedding(
             response = client.invoke_endpoint(
                 EndpointName={model},
                 ContentType="application/json",
-                Body={data},  # type: ignore
+                Body=f"{data!r}",  # Use !r for safe representation
                 CustomAttributes="accept_eula=true",
             )"""  # type: ignore
         logging_obj.pre_call(