Commit e85c46c

Author: Adrian Lyjak

fix #9692. Keep cache key stable during mutation

A) Return a copy from strict key removal, to not break cache keys.
B) Fix an issue in the existing cache key stabilizer that did not store a stable key across request/response when no litellm_params existed in the request.

1 parent 64bb89c commit e85c46c
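
The failure mode behind #9692 is easy to reproduce in miniature: when a cache key is derived from the request kwargs, any in-place mutation of those kwargs between request time and response time yields a different key at write time. A minimal sketch of the mechanism, with a hypothetical make_cache_key helper standing in for litellm's real hashing:

import hashlib
import json


def make_cache_key(kwargs: dict) -> str:
    # Hypothetical stand-in for the real cache-key hashing: serialize the
    # request kwargs deterministically and hash them.
    serialized = json.dumps(kwargs, sort_keys=True, default=str)
    return hashlib.sha256(serialized.encode()).hexdigest()


request = {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]}
key_at_request_time = make_cache_key(request)

# A provider transform that mutates the request in place (e.g. stripping
# "strict" from a JSON schema) silently changes the key...
request["messages"][0]["content"] = "normalized"
key_at_response_time = make_cache_key(request)

# ...so the response gets stored under a key that later lookups never hit.
assert key_at_request_time != key_at_response_time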

File tree

4 files changed: +154 -9 lines changed

litellm/caching/caching.py
litellm/utils.py
tests/litellm_utils_tests/test_utils.py
tests/local_testing/test_unit_test_caching.py


litellm/caching/caching.py

Lines changed: 2 additions & 3 deletions

@@ -331,10 +331,9 @@ def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
         """
         Get the preset cache key from kwargs["litellm_params"]
 
-        We use _get_preset_cache_keys for two reasons
+        Is set after the cache is first calculated in order to not mutate between request and response time,
+        in case the implementation mutates the original objects (and avoids doing duplicate key calculations)
 
-        1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
-        2. avoid doing duplicate / repeated work
         """
         if kwargs:
             if "litellm_params" in kwargs:

litellm/utils.py

Lines changed: 24 additions & 6 deletions

@@ -1264,6 +1264,14 @@ async def wrapper_async(*args, **kwargs):  # noqa: PLR0915
        if "litellm_call_id" not in kwargs:
            kwargs["litellm_call_id"] = str(uuid.uuid4())
 
+        # set up litellm_params, so that keys can be added (e.g. for tracking cache keys)
+        if "litellm_params" not in kwargs:
+            kwargs["litellm_params"] = {}
+        # without copying, something goes wrong deep in the cost logging,
+        # where metadata would be read from if litellm_params is None
+        if "metadata" in kwargs:
+            kwargs["litellm_params"]["metadata"] = kwargs["metadata"]
+
        model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None)
        is_completion_with_fallbacks = kwargs.get("fallbacks") is not None
 
@@ -2794,23 +2802,33 @@ def _remove_additional_properties(schema):
 
 def _remove_strict_from_schema(schema):
     """
-    Relevant Issues: https://github.com/BerriAI/litellm/issues/6136, https://github.com/BerriAI/litellm/issues/6088
+    Recursively removes 'strict' from schema. Returns a copy, in order to not break cache keys, (so you should update your reference)
+
+    Relevant Issues: https://github.com/BerriAI/litellm/issues/6136, https://github.com/BerriAI/litellm/issues/6088,
     """
+    maybe_copy = None  # make a copy to not break cache keys https://github.com/BerriAI/litellm/issues/9692
     if isinstance(schema, dict):
         # Remove the 'additionalProperties' key if it exists and is set to False
         if "strict" in schema:
-            del schema["strict"]
+            maybe_copy = schema.copy()
+            del maybe_copy["strict"]
 
         # Recursively process all dictionary values
         for key, value in schema.items():
-            _remove_strict_from_schema(value)
+            result = _remove_strict_from_schema(value)
+            if result is not value:
+                maybe_copy = maybe_copy or schema.copy()
+                maybe_copy[key] = result
 
    elif isinstance(schema, list):
        # Recursively process all items in the list
-        for item in schema:
-            _remove_strict_from_schema(item)
+        for i, item in enumerate(schema):
+            result = _remove_strict_from_schema(item)
+            if result is not item:
+                maybe_copy = maybe_copy or list(schema)
+                maybe_copy[i] = result
 
-    return schema
+    return maybe_copy or schema
 
 
 def _remove_unsupported_params(
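
Since _remove_strict_from_schema now copies instead of mutating, call sites must rebind the result rather than rely on the old side effect, which is what the new docstring's "update your reference" warning means. An illustrative call site (not quoted from the repo):

from litellm.utils import _remove_strict_from_schema

schema = {"type": "object", "strict": True, "properties": {"name": {"type": "string"}}}

# Old convention: the function mutated `schema` in place, so the return
# value could be ignored. With copy-on-write, ignoring it would leave the
# caller's dict unchanged.
schema = _remove_strict_from_schema(schema)  # rebind to pick up the copy

assert "strict" not in schema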

tests/litellm_utils_tests/test_utils.py

Lines changed: 105 additions & 0 deletions

@@ -1,4 +1,5 @@
 import copy
+import json
 import sys
 import time
 from datetime import datetime
@@ -1910,6 +1911,43 @@ async def test_function(**kwargs):
         == "gpt-4o-mini"
     )
 
+@pytest.mark.asyncio
+async def test_cache_key_stability_with_mutation(monkeypatch):
+    from litellm.utils import client
+    import asyncio
+    from litellm.caching import Cache
+
+    # Set up in-memory cache
+    cache = Cache()
+    monkeypatch.setattr(litellm, "cache", cache)
+
+    # Create mock original function
+    mock_original = AsyncMock()
+
+    def side_effect(**kwargs):
+        print(f"kwargs: {kwargs}")
+        return litellm.ModelResponse(
+            model="vertex_ai/gemini-2.0-flash"
+        )
+    mock_original.side_effect = side_effect
+
+    # Apply decorator
+    @client
+    async def acompletion(**kwargs):
+        kwargs["messages"][0]["content"] = "mutated"
+        return await mock_original(**kwargs)
+
+    # Test kwargs
+    test_kwargs = {"model": "vertex_ai/gemini-2.0-flash", "messages": [{"role": "user", "content": "Hello, world!"}]}
+    original_kwargs = copy.deepcopy(test_kwargs)
+
+    # Call decorated function
+    await acompletion(**test_kwargs)
+    await asyncio.sleep(0.01)
+    await acompletion(**original_kwargs)
+
+    mock_original.assert_called_once()
+
 
 def test_dict_to_response_format_helper():
     from litellm.llms.base_llm.base_utils import _dict_to_response_format_helper
@@ -2122,3 +2160,70 @@ def test_get_provider_audio_transcription_config():
     config = ProviderConfigManager.get_provider_audio_transcription_config(
         model="whisper-1", provider=provider
     )
+
+
+def test_remove_strict_from_schema():
+    from litellm.utils import _remove_strict_from_schema
+
+    schema = {  # This isn't maybe actually very realistic json schema, just slop full of stricts
+        "$schema": "http://json-schema.org/draft-07/schema#",
+        "type": "object",
+        "strict": True,
+        "definitions": {
+            "address": {
+                "type": "object",
+                "properties": {
+                    "street": {"type": "string"},
+                    "city": {"type": "string"}
+                },
+                "required": ["street", "city"],
+                "strict": True
+            }
+        },
+        "properties": {
+            "name": {
+                "type": "string",
+                "strict": True
+            },
+            "age": {
+                "type": "integer"
+            },
+            "address": {
+                "$ref": "#/definitions/address"
+            },
+            "tags": {
+                "type": "array",
+                "items": {"type": "string"},
+                "strict": True
+            },
+            "contacts": {
+                "type": "array",
+                "items": {
+                    "oneOf": [
+                        {"type": "string"},
+                        {
+                            "type": "array",
+                            "items": {
+                                "type": "object",
+                                "strict": True,
+                                "properties": {
+                                    "value": {"type": "string"}
+                                },
+                                "required": ["value"]
+                            }
+                        }
+                    ],
+                    "strict": True
+                }
+            }
+        }
+    }
+    original_schema = copy.deepcopy(schema)
+    cleaned = _remove_strict_from_schema(schema)
+    assert "strict" not in json.dumps(cleaned)
+    # schema should be unchanged, (should copy instead of mutate)
+    # otherwise it breaks cache keys
+    # https://github.com/BerriAI/litellm/issues/9692
+    assert cleaned != original_schema
+    assert schema == original_schema

tests/local_testing/test_unit_test_caching.py

Lines changed: 23 additions & 0 deletions

@@ -251,3 +251,26 @@ def test_generate_streaming_content():
     assert chunk_count > 1
 
     print(f"Number of chunks: {chunk_count}")
+
+def test_caching_stable_with_mutation():
+    """
+    Test that caching is stable with mutation during a request. This is to circumvent the cache miss when a provider
+    implementation mutates an argument of the original request (e.g. to normalize it for the specific provider).
+    Otherwise the response is stored under a different key than the cache is checked with.
+    """
+    litellm.cache = Cache()
+    kwargs = {
+        "model": "gpt-3.5-turbo",
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello, world!"},
+        ],
+        "temperature": 0.7,
+        "litellm_params": {},  # litellm_params must be set for this to work, otherwise it's not spread into the kwargs
+    }
+    cache_key = litellm.cache.get_cache_key(**kwargs)
+
+    # mutate kwargs
+    kwargs["messages"][0]["role"] = "developer"
+    cache_key_2 = litellm.cache.get_cache_key(**kwargs)
+    assert cache_key == cache_key_2
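
The inline comment about litellm_params points at the mechanics of part B: get_cache_key(**kwargs) receives a fresh top-level dict, so the stabilizer can only persist the preset key by mutating a litellm_params dict that already exists in the caller's kwargs. A toy illustration of that Python calling behavior (hypothetical helper, not litellm's actual signature):

def get_key_with_stabilizer(**kwargs) -> str:
    # **kwargs builds a new top-level dict, so inserting a brand-new
    # "litellm_params" entry here would be invisible to the caller; only
    # mutations of an existing nested dict survive the call.
    params = kwargs.get("litellm_params")
    if params is not None:
        params["preset_cache_key"] = "stable-key"
    return "stable-key"


request = {"model": "gpt-3.5-turbo", "litellm_params": {}}
get_key_with_stabilizer(**request)

# The nested dict was shared, so the preset key persisted for the
# response-time lookup to reuse.
assert request["litellm_params"]["preset_cache_key"] == "stable-key"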
