
Return a copy from strict key removal to not break cache keys #9693


Status: Open. Wants to merge 1 commit into base: main.
5 changes: 2 additions & 3 deletions litellm/caching/caching.py
@@ -331,10 +331,9 @@ def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
        """
        Get the preset cache key from kwargs["litellm_params"]

-       We use _get_preset_cache_keys for two reasons
+       The preset cache key is set right after it is first calculated, so it does not change between request and
+       response time if a provider implementation mutates the original objects (and avoids duplicate key calculations)

-       1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
-       2. avoid doing duplicate / repeated work
        """
        if kwargs:
            if "litellm_params" in kwargs:
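A rough sketch of the memoization this docstring describes (not litellm's actual implementation; the get_or_set_preset_cache_key helper, the compute_key callable, and the "preset_cache_key" field name are assumptions): the key is computed once at request time, stashed on kwargs["litellm_params"], and reused at response time even if the request objects were mutated in between. This is also why the wrapper change below pre-populates kwargs["litellm_params"].

# sketch only: assumed names, not litellm's real internals
def get_or_set_preset_cache_key(kwargs: dict, compute_key) -> str:
    litellm_params = kwargs.setdefault("litellm_params", {})
    if "preset_cache_key" not in litellm_params:
        # first calculation, at request time, before any provider code mutates kwargs
        litellm_params["preset_cache_key"] = compute_key(**kwargs)
    # later calls (e.g. when storing the response) reuse the same key
    return litellm_params["preset_cache_key"]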
30 changes: 24 additions & 6 deletions litellm/utils.py
@@ -1264,6 +1264,14 @@ async def wrapper_async(*args, **kwargs): # noqa: PLR0915
        if "litellm_call_id" not in kwargs:
            kwargs["litellm_call_id"] = str(uuid.uuid4())

+       # set up litellm_params, so that keys can be added (e.g. for tracking cache keys)
+       if "litellm_params" not in kwargs:
+           kwargs["litellm_params"] = {}
+       # without copying, something goes wrong deep in the cost logging,
+       # where metadata would be read from if litellm_params is None
+       if "metadata" in kwargs:
+           kwargs["litellm_params"]["metadata"] = kwargs["metadata"]

        model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None)
        is_completion_with_fallbacks = kwargs.get("fallbacks") is not None
@@ -2794,23 +2802,33 @@ def _remove_additional_properties(schema):

def _remove_strict_from_schema(schema):
    """
-   Relevant Issues: https://github.com/BerriAI/litellm/issues/6136, https://github.com/BerriAI/litellm/issues/6088
+   Recursively removes 'strict' from the schema. Returns a copy rather than mutating in place, so cache keys are not broken (use the returned value, not the original reference)
+
+   Relevant Issues: https://github.com/BerriAI/litellm/issues/6136, https://github.com/BerriAI/litellm/issues/6088
    """
+   maybe_copy = None  # make a copy to not break cache keys https://github.com/BerriAI/litellm/issues/9692
    if isinstance(schema, dict):
        # Remove the 'additionalProperties' key if it exists and is set to False
        if "strict" in schema:
-           del schema["strict"]
+           maybe_copy = schema.copy()
+           del maybe_copy["strict"]

        # Recursively process all dictionary values
        for key, value in schema.items():
-           _remove_strict_from_schema(value)
+           result = _remove_strict_from_schema(value)
+           if result is not value:
+               maybe_copy = maybe_copy or schema.copy()
+               maybe_copy[key] = result

    elif isinstance(schema, list):
        # Recursively process all items in the list
-       for item in schema:
-           _remove_strict_from_schema(item)
+       for i, item in enumerate(schema):
+           result = _remove_strict_from_schema(item)
+           if result is not item:
+               maybe_copy = maybe_copy or list(schema)
+               maybe_copy[i] = result

-   return schema
+   return maybe_copy or schema


def _remove_unsupported_params(
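For illustration, a small usage sketch of the copy-on-write behavior introduced above (assuming the function is imported from litellm.utils, as the tests below do): the input schema is left untouched and callers must use the returned value.

from litellm.utils import _remove_strict_from_schema

schema = {"type": "object", "strict": True, "properties": {"name": {"type": "string", "strict": True}}}
cleaned = _remove_strict_from_schema(schema)

assert "strict" in schema                              # original is not mutated, so cache keys stay stable
assert "strict" not in cleaned                         # top-level 'strict' stripped from the returned copy
assert "strict" not in cleaned["properties"]["name"]   # nested 'strict' stripped as well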
105 changes: 105 additions & 0 deletions tests/litellm_utils_tests/test_utils.py
@@ -1,4 +1,5 @@
+import copy
import json
import sys
import time
from datetime import datetime
@@ -1910,6 +1911,43 @@ async def test_function(**kwargs):
        == "gpt-4o-mini"
    )

@pytest.mark.asyncio
async def test_cache_key_stability_with_mutation(monkeypatch):
    from litellm.utils import client
    import asyncio
    from litellm.caching import Cache

    # Set up in-memory cache
    cache = Cache()
    monkeypatch.setattr(litellm, "cache", cache)

    # Create mock original function
    mock_original = AsyncMock()

    def side_effect(**kwargs):
        print(f"kwargs: {kwargs}")
        return litellm.ModelResponse(
            model="vertex_ai/gemini-2.0-flash"
        )
    mock_original.side_effect = side_effect

    # Apply decorator
    @client
    async def acompletion(**kwargs):
        kwargs["messages"][0]["content"] = "mutated"
        return await mock_original(**kwargs)

    # Test kwargs
    test_kwargs = {"model": "vertex_ai/gemini-2.0-flash", "messages": [{"role": "user", "content": "Hello, world!"}]}
    original_kwargs = copy.deepcopy(test_kwargs)

    # Call decorated function
    await acompletion(**test_kwargs)
    await asyncio.sleep(0.01)
    await acompletion(**original_kwargs)

    mock_original.assert_called_once()

def test_dict_to_response_format_helper():
    from litellm.llms.base_llm.base_utils import _dict_to_response_format_helper
@@ -2122,3 +2160,70 @@ def test_get_provider_audio_transcription_config():
    config = ProviderConfigManager.get_provider_audio_transcription_config(
        model="whisper-1", provider=provider
    )

def test_remove_strict_from_schema():
    from litellm.utils import _remove_strict_from_schema

    schema = {  # not necessarily realistic JSON Schema, just a structure full of 'strict' keys
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "object",
        "strict": True,
        "definitions": {
            "address": {
                "type": "object",
                "properties": {
                    "street": {"type": "string"},
                    "city": {"type": "string"}
                },
                "required": ["street", "city"],
                "strict": True
            }
        },
        "properties": {
            "name": {
                "type": "string",
                "strict": True
            },
            "age": {
                "type": "integer"
            },
            "address": {
                "$ref": "#/definitions/address"
            },
            "tags": {
                "type": "array",
                "items": {"type": "string"},
                "strict": True
            },
            "contacts": {
                "type": "array",
                "items": {
                    "oneOf": [
                        {"type": "string"},
                        {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "strict": True,
                                "properties": {
                                    "value": {"type": "string"}
                                },
                                "required": ["value"]
                            }
                        }
                    ],
                    "strict": True
                }
            }
        }
    }
    original_schema = copy.deepcopy(schema)
    cleaned = _remove_strict_from_schema(schema)
    assert "strict" not in json.dumps(cleaned)
    # schema should be unchanged (copy instead of mutate),
    # otherwise it breaks cache keys
    # https://github.com/BerriAI/litellm/issues/9692
    assert cleaned != original_schema
    assert schema == original_schema


23 changes: 23 additions & 0 deletions tests/local_testing/test_unit_test_caching.py
@@ -251,3 +251,26 @@ def test_generate_streaming_content():
    assert chunk_count > 1

    print(f"Number of chunks: {chunk_count}")

def test_caching_stable_with_mutation():
    """
    Test that caching is stable under mutation during a request. This prevents the cache miss that occurs when a
    provider implementation mutates an argument of the original request (e.g. to normalize it for that provider).
    Otherwise the response is stored under a different key than the one the cache was checked with.
    """
    litellm.cache = Cache()
    kwargs = {
        "model": "o1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, world!"},
        ],
        "temperature": 0.7,
        "litellm_params": {},  # litellm_params must be set for this to work, otherwise it's not spread into the kwargs
    }
    cache_key = litellm.cache.get_cache_key(**kwargs)

    # mutate kwargs
    kwargs["messages"][0]["role"] = "developer"
    cache_key_2 = litellm.cache.get_cache_key(**kwargs)
    assert cache_key == cache_key_2
Contributor (krrishdholakia) commented:
hey @adrianlyjak this is not desired behaviour - if the user changes an optional param, we do not want to return a cached response

Contributor Author (adrianlyjak) commented on Apr 2, 2025:
@krrishdholakia I didn't change anything to make this particular test pass; this is actually the current functionality. It appears to be the existing intended behavior of the code: to memoize the cache key within a single request.

This test scenario is perhaps a little unrealistic, since the temperature itself can't be changed (the kwargs are spread, which copies the dict). However, if a nested parameter such as the response schema is mutated between the start of the request and the response, the same cache key is still used. The related fix I implemented was just to ensure litellm_params is initialized (so the cache key is actually memoized).
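For example (a condensed, hypothetical illustration, not an actual litellm integration), a provider that normalizes a nested field in place would otherwise produce a different key at response time than the one used for the lookup:

request_kwargs = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "hi"}],
    "response_format": {"type": "json_schema", "json_schema": {"schema": {"strict": True}}},
}
key_at_request_time = str(sorted(request_kwargs.items()))  # stand-in for the real cache key function

# a provider integration normalizes the schema in place (the pre-fix behavior)
request_kwargs["response_format"]["json_schema"]["schema"].pop("strict")

key_at_response_time = str(sorted(request_kwargs.items()))
assert key_at_request_time != key_at_response_time  # response is stored under a key the earlier lookup never used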

Contributor Author (adrianlyjak) commented:
> if the user changes an optional param

To be clear, as I understand it, this would only happen internally within provider integrations. I don't know of a way for the user to modify the request parameters after calling completion (or another entrypoint function).

Contributor Author (adrianlyjak) commented:
Adjusted the test for clarity: modifying temperature isn't really an expected use case, so I instead normalized the system -> developer role on a message.
