Skip to content

Add support to Vertex AI transformation for anyOf union type with null fields #9625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .github/workflows/test-litellm.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# CI workflow: runs the LiteLLM test suite on every pull request targeting main.
# (Indentation restored — the YAML as captured had lost its structural indent.)
name: LiteLLM Tests

on:
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    # Keep CI fast; fail the job if the suite hangs.
    timeout-minutes: 5

    steps:
      - uses: actions/checkout@v4

      - name: Thank You Message
        run: |
          echo "### 🙏 Thank you for contributing to LiteLLM!" >> $GITHUB_STEP_SUMMARY
          echo "Your PR is being tested now. We appreciate your help in making LiteLLM better!" >> $GITHUB_STEP_SUMMARY

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'

      - name: Install Poetry
        uses: snok/install-poetry@v1

      - name: Install dependencies
        run: |
          poetry install --with dev,proxy-dev --extras proxy
          poetry run pip install pytest-xdist

      - name: Run tests
        run: |
          # -x: stop on first failure; -n 4: parallelize via pytest-xdist
          poetry run pytest tests/litellm -x -vv -n 4
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ tests/llm_translation/test_vertex_key.json
litellm/proxy/migrations/0_init/migration.sql
litellm/proxy/db/migrations/0_init/migration.sql
litellm/proxy/db/migrations/*
litellm/proxy/migrations/*
litellm/proxy/migrations/*config.yaml
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ help:
install-dev:
poetry install --with dev

install-proxy-dev:
poetry install --with dev,proxy-dev

lint: install-dev
poetry run pip install types-requests types-setuptools types-redis types-PyYAML
cd litellm && poetry run mypy . --ignore-missing-imports
Expand Down
2 changes: 1 addition & 1 deletion deploy/charts/litellm-helm/templates/migrations-job.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ spec:
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}
ttlSecondsAfterFinished: {{ .Values.migrationJob.ttlSecondsAfterFinished }}
ttlSecondsAfterFinished: {{ .Values.migrationJob.ttlSecondsAfterFinished }}
backoffLimit: {{ .Values.migrationJob.backoffLimit }}
{{- end }}
60 changes: 57 additions & 3 deletions litellm/_logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
import sys
from datetime import datetime
from logging import Formatter

Expand Down Expand Up @@ -40,9 +41,56 @@ def format(self, record):
return json.dumps(json_record)


# Function to set up exception handlers for JSON logging
def _setup_json_exception_handlers(formatter):
# Create a handler with JSON formatting for exceptions
error_handler = logging.StreamHandler()
error_handler.setFormatter(formatter)

# Setup excepthook for uncaught exceptions
def json_excepthook(exc_type, exc_value, exc_traceback):
record = logging.LogRecord(
name="LiteLLM",
level=logging.ERROR,
pathname="",
lineno=0,
msg=str(exc_value),
args=(),
exc_info=(exc_type, exc_value, exc_traceback),
)
error_handler.handle(record)

sys.excepthook = json_excepthook

# Configure asyncio exception handler if possible
try:
import asyncio

def async_json_exception_handler(loop, context):
exception = context.get("exception")
if exception:
record = logging.LogRecord(
name="LiteLLM",
level=logging.ERROR,
pathname="",
lineno=0,
msg=str(exception),
args=(),
exc_info=None,
)
error_handler.handle(record)
else:
loop.default_exception_handler(context)

asyncio.get_event_loop().set_exception_handler(async_json_exception_handler)
except Exception:
pass


# Create a formatter and set it for the handler
if json_logs:
handler.setFormatter(JsonFormatter())
_setup_json_exception_handlers(JsonFormatter())
else:
formatter = logging.Formatter(
"\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
Expand All @@ -65,18 +113,24 @@ def _turn_on_json():
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())

# Define a list of the loggers to update
loggers = [verbose_router_logger, verbose_proxy_logger, verbose_logger]
# Define all loggers to update, including root logger
loggers = [logging.getLogger()] + [
verbose_router_logger,
verbose_proxy_logger,
verbose_logger,
]

# Iterate through each logger and update its handlers
for logger in loggers:
# Remove all existing handlers
for h in logger.handlers[:]:
logger.removeHandler(h)

# Add the new handler
logger.addHandler(handler)

# Set up exception handlers
_setup_json_exception_handlers(JsonFormatter())


def _turn_on_debug():
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
Expand Down
1 change: 1 addition & 0 deletions litellm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
DEFAULT_MAX_RETRIES = 2
DEFAULT_MAX_RECURSE_DEPTH = 10
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
)
Expand Down
3 changes: 2 additions & 1 deletion litellm/litellm_core_utils/safe_json_dumps.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
from typing import Any, Union
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH


def safe_dumps(data: Any, max_depth: int = 10) -> str:
def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
"""
Recursively serialize data while detecting circular references.
If a circular reference is detected then a marker string is returned.
Expand Down
3 changes: 2 additions & 1 deletion litellm/litellm_core_utils/sensitive_data_masker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Set
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH


class SensitiveDataMasker:
Expand Down Expand Up @@ -39,7 +40,7 @@ def is_sensitive_key(self, key: str) -> bool:
return result

def mask_dict(
self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10
self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
) -> Dict[str, Any]:
if depth >= max_depth:
return data
Expand Down
48 changes: 29 additions & 19 deletions litellm/llms/vertex_ai/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import litellm
from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType

Expand Down Expand Up @@ -177,7 +178,7 @@ def _build_vertex_schema(parameters: dict):
# * https://github.com/pydantic/pydantic/issues/1270
# * https://stackoverflow.com/a/58841311
# * https://github.com/pydantic/pydantic/discussions/4872
convert_to_nullable(parameters)
convert_anyof_null_to_nullable(parameters)
add_object_type(parameters)
# Postprocessing
# 4. Suppress unnecessary title generation:
Expand Down Expand Up @@ -228,34 +229,43 @@ def unpack_defs(schema, defs):
continue


def convert_to_nullable(schema):
anyof = schema.pop("anyOf", None)
def convert_anyof_null_to_nullable(schema, depth=0):
    """Convert ``anyOf`` unions with ``null`` into Vertex AI's ``nullable`` form.

    Removes every ``{"type": "null"}`` member from an ``anyOf`` list and, when
    at least one was present, marks all remaining members ``nullable: True``
    (per Vertex AI controlled-generation guidance). Recurses into
    ``properties`` values and ``items``. Mutates ``schema`` in place.

    Args:
        schema: JSON-schema dict to transform in place.
        depth: current recursion depth (internal; callers use the default).

    Raises:
        ValueError: if nesting exceeds ``DEFAULT_MAX_RECURSE_DEPTH``, or if an
            ``anyOf`` contains only null types (invalid for Vertex AI).
    """
    if depth > DEFAULT_MAX_RECURSE_DEPTH:
        raise ValueError(
            f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
        )
    anyof = schema.get("anyOf", None)
    if anyof is not None:
        # Rebuild the list instead of calling list.remove() while iterating:
        # mutating a list during iteration skips elements, so consecutive
        # null entries (e.g. anyOf: [null, null]) would not all be removed
        # and the only-null validation below could be bypassed.
        non_null_types = [atype for atype in anyof if atype != {"type": "null"}]
        contains_null = len(non_null_types) != len(anyof)

        if len(non_null_types) == 0:
            # Edge case: response schema with only null type present is invalid in Vertex AI
            raise ValueError(
                "Invalid input: AnyOf schema with only null type is not supported. "
                "Please provide a non-null type."
            )

        if contains_null:
            # set all types to nullable following guidance found here: https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-controlled-generation-response-schema-3#generativeaionvertexai_gemini_controlled_generation_response_schema_3-python
            for atype in non_null_types:
                atype["nullable"] = True

        schema["anyOf"] = non_null_types

    properties = schema.get("properties", None)
    if properties is not None:
        for name, value in properties.items():
            convert_anyof_null_to_nullable(value, depth=depth + 1)

    items = schema.get("items", None)
    if items is not None:
        convert_anyof_null_to_nullable(items, depth=depth + 1)


def add_object_type(schema):
Expand Down
4 changes: 2 additions & 2 deletions litellm/llms/xai/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional

import httpx

Expand All @@ -22,7 +22,7 @@ def get_base_model(model: str) -> Optional[str]:

def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> list[str]:
) -> List[str]:
api_base = self.get_api_base(api_base)
api_key = self.get_api_key(api_key)
if api_base is None or api_key is None:
Expand Down
1 change: 1 addition & 0 deletions litellm/proxy/_new_secret_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ model_list:
litellm_settings:
num_retries: 0
callbacks: ["prometheus"]
json_logs: true

router_settings:
routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
Expand Down
5 changes: 4 additions & 1 deletion litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ModelResponseStream,
TextCompletionResponse,
)
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH

if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Expand Down Expand Up @@ -462,6 +463,8 @@ async def proxy_startup_event(app: FastAPI):
if premium_user is False:
premium_user = _license_check.is_premium()

## CHECK MASTER KEY IN ENVIRONMENT ##
master_key = get_secret_str("LITELLM_MASTER_KEY")
### LOAD CONFIG ###
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
Expand Down Expand Up @@ -1522,7 +1525,7 @@ async def save_config(self, new_config: dict):
yaml.dump(new_config, config_file, default_flow_style=False)

def _check_for_os_environ_vars(
self, config: dict, depth: int = 0, max_depth: int = 10
self, config: dict, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
) -> dict:
"""
Check for os.environ/ variables in the config and replace them with the actual values.
Expand Down
Loading
Loading