
Commit 0dbd663

fix(cost_calculator.py): handle custom pricing at deployment level for router (BerriAI#9855)
* fix(cost_calculator.py): handle custom pricing at deployment level for router
* test: add unit tests
* fix(router.py): show custom pricing on UI check correct model str
* fix: fix linting error
* docs(custom_pricing.md): clarify custom pricing for proxy. Fixes BerriAI#8573 (comment)
* test: update code qa test
* fix: cleanup traceback
* fix: handle litellm param custom pricing
* test: update test
* fix(cost_calculator.py): add router model id to list of potential model names
* fix(cost_calculator.py): fix router model id check
* fix: router.py - maintain older model registry approach
* fix: fix ruff check
* fix(router.py): router get deployment info add custom values to mapped dict
* test: update test
* fix(utils.py): update only if value is non-null
* test: add unit test
1 parent 0c5b4aa commit 0dbd663

16 files changed: +193 −37 lines changed

docs/my-website/docs/proxy/custom_pricing.md

Lines changed: 25 additions & 1 deletion
@@ -26,10 +26,12 @@ model_list:
   - model_name: sagemaker-completion-model
     litellm_params:
       model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
+    model_info:
       input_cost_per_second: 0.000420
   - model_name: sagemaker-embedding-model
     litellm_params:
       model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
+    model_info:
       input_cost_per_second: 0.000420
 ```
 
@@ -55,11 +57,33 @@ model_list:
       api_key: os.environ/AZURE_API_KEY
       api_base: os.environ/AZURE_API_BASE
       api_version: os.envrion/AZURE_API_VERSION
+    model_info:
       input_cost_per_token: 0.000421 # 👈 ONLY to track cost per token
       output_cost_per_token: 0.000520 # 👈 ONLY to track cost per token
 ```
 
-### Debugging
+## Override Model Cost Map
+
+You can override [our model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) with your own custom pricing for a mapped model.
+
+Just add a `model_info` key to your model in the config, and override the desired keys.
+
+Example: Override Anthropic's model cost map for the `prod/claude-3-5-sonnet-20241022` model.
+
+```yaml
+model_list:
+  - model_name: "prod/claude-3-5-sonnet-20241022"
+    litellm_params:
+      model: "anthropic/claude-3-5-sonnet-20241022"
+      api_key: os.environ/ANTHROPIC_PROD_API_KEY
+    model_info:
+      input_cost_per_token: 0.000006
+      output_cost_per_token: 0.00003
+      cache_creation_input_token_cost: 0.0000075
+      cache_read_input_token_cost: 0.0000006
+```
+
+## Debugging
 
 If you're custom pricing is not being used or you're seeing errors, please check the following:
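The same override can also be exercised programmatically. Below is a minimal, untested sketch using `litellm.register_model` (the call the router now uses internally for deployment-level pricing) together with `litellm.completion_cost`; the prices are illustrative only:

```python
import litellm

# Override the mapped pricing for a model, mirroring the YAML `model_info` block above.
# (Illustrative prices; any keys from the model cost map can be overridden.)
litellm.register_model(
    model_cost={
        "anthropic/claude-3-5-sonnet-20241022": {
            "input_cost_per_token": 0.000006,
            "output_cost_per_token": 0.00003,
        }
    }
)

# Cost calculation now resolves against the overridden entry.
cost = litellm.completion_cost(
    model="anthropic/claude-3-5-sonnet-20241022",
    prompt="Hello",
    completion="Hi there!",
)
print(f"cost: ${cost:.6f}")
```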

litellm/cost_calculator.py

Lines changed: 16 additions & 7 deletions
@@ -403,6 +403,7 @@ def _select_model_name_for_cost_calc(
     base_model: Optional[str] = None,
     custom_pricing: Optional[bool] = None,
     custom_llm_provider: Optional[str] = None,
+    router_model_id: Optional[str] = None,
 ) -> Optional[str]:
     """
     1. If custom pricing is true, return received model name
@@ -417,19 +418,23 @@ def _select_model_name_for_cost_calc(
         model=model, custom_llm_provider=custom_llm_provider
     )
 
-    if custom_pricing is True:
-        return_model = model
-
-    if base_model is not None:
-        return_model = base_model
-
     completion_response_model: Optional[str] = None
     if completion_response is not None:
         if isinstance(completion_response, BaseModel):
             completion_response_model = getattr(completion_response, "model", None)
         elif isinstance(completion_response, dict):
             completion_response_model = completion_response.get("model", None)
     hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
+
+    if custom_pricing is True:
+        if router_model_id is not None and router_model_id in litellm.model_cost:
+            return_model = router_model_id
+        else:
+            return_model = model
+
+    if base_model is not None:
+        return_model = base_model
+
     if completion_response_model is None and hidden_params is not None:
         if (
             hidden_params.get("model", None) is not None
@@ -559,6 +564,7 @@ def completion_cost(  # noqa: PLR0915
     base_model: Optional[str] = None,
     standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
     litellm_model_name: Optional[str] = None,
+    router_model_id: Optional[str] = None,
 ) -> float:
     """
     Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.
@@ -617,12 +623,12 @@ def completion_cost(  # noqa: PLR0915
         custom_llm_provider=custom_llm_provider,
         custom_pricing=custom_pricing,
         base_model=base_model,
+        router_model_id=router_model_id,
     )
 
     potential_model_names = [selected_model]
     if model is not None:
         potential_model_names.append(model)
-
     for idx, model in enumerate(potential_model_names):
         try:
             verbose_logger.info(
@@ -943,6 +949,7 @@ def response_cost_calculator(
     prompt: str = "",
     standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
     litellm_model_name: Optional[str] = None,
+    router_model_id: Optional[str] = None,
 ) -> float:
     """
     Returns
@@ -973,6 +980,8 @@ def response_cost_calculator(
             base_model=base_model,
             prompt=prompt,
             standard_built_in_tools_params=standard_built_in_tools_params,
+            litellm_model_name=litellm_model_name,
+            router_model_id=router_model_id,
         )
         return response_cost
     except Exception as e:
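In effect, the new precedence for custom-priced calls is: registered router model id, then the received model name, with `base_model` still overriding both. A standalone distillation of that branch (a hypothetical helper for illustration, not the library function):

```python
from typing import Optional

def pick_pricing_key(
    model: str,
    custom_pricing: bool,
    model_cost_map: dict,
    router_model_id: Optional[str] = None,
    base_model: Optional[str] = None,
) -> str:
    """Hypothetical distillation of _select_model_name_for_cost_calc's new branch."""
    return_model = model
    if custom_pricing:
        # Prefer the deployment's registered id when the router has put its
        # custom pricing into the cost map under that id.
        if router_model_id is not None and router_model_id in model_cost_map:
            return_model = router_model_id
        else:
            return_model = model
    if base_model is not None:
        # base_model still wins, preserving the existing override semantics.
        return_model = base_model
    return return_model
```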

litellm/litellm_core_utils/get_supported_openai_params.py

Lines changed: 0 additions & 4 deletions
@@ -1,7 +1,6 @@
 from typing import Literal, Optional
 
 import litellm
-from litellm._logging import verbose_logger
 from litellm.exceptions import BadRequestError
 from litellm.types.utils import LlmProviders, LlmProvidersSet
 
@@ -43,9 +42,6 @@ def get_supported_openai_params(  # noqa: PLR0915
         provider_config = None
 
     if provider_config and request_type == "chat_completion":
-        verbose_logger.info(
-            f"using provider_config: {provider_config} for checking supported params"
-        )
         return provider_config.get_supported_openai_params(model=model)
 
     if custom_llm_provider == "bedrock":

litellm/litellm_core_utils/litellm_logging.py

Lines changed: 2 additions & 1 deletion
@@ -622,7 +622,6 @@ def pre_call(self, input, api_key, model=None, additional_args={}):  # noqa: PLR
             ] = RawRequestTypedDict(
                 error=str(e),
             )
-            traceback.print_exc()
             _metadata[
                 "raw_request"
             ] = "Unable to Log \
@@ -906,6 +905,7 @@ def _response_cost_calculator(
         ],
         cache_hit: Optional[bool] = None,
         litellm_model_name: Optional[str] = None,
+        router_model_id: Optional[str] = None,
     ) -> Optional[float]:
         """
         Calculate response cost using result + logging object variables.
@@ -944,6 +944,7 @@ def _response_cost_calculator(
                 "custom_pricing": custom_pricing,
                 "prompt": prompt,
                 "standard_built_in_tools_params": self.standard_built_in_tools_params,
+                "router_model_id": router_model_id,
             }
         except Exception as e:  # error creating kwargs for cost calculation
             debug_info = StandardLoggingModelCostFailureDebugInformation(

litellm/litellm_core_utils/llm_response_utils/response_metadata.py

Lines changed: 7 additions & 2 deletions
@@ -36,11 +36,16 @@ def set_hidden_params(
         self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
     ) -> None:
         """Set hidden parameters on the response"""
+
+        ## ADD OTHER HIDDEN PARAMS
+        model_id = kwargs.get("model_info", {}).get("id", None)
         new_params = {
             "litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
-            "model_id": kwargs.get("model_info", {}).get("id", None),
             "api_base": get_api_base(model=model or "", optional_params=kwargs),
-            "response_cost": logging_obj._response_cost_calculator(result=self.result),
+            "model_id": model_id,
+            "response_cost": logging_obj._response_cost_calculator(
+                result=self.result, litellm_model_name=model, router_model_id=model_id
+            ),
             "additional_headers": process_response_headers(
                 self._get_value_from_hidden_params("additional_headers") or {}
             ),
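With this change, the deployment id used to price the call is the same `model_id` surfaced in the response's hidden params. A hedged sketch of reading both back, assuming `router` is an already-configured `litellm.Router`:

```python
# Hypothetical usage: inspect the deployment id and cost on a router response.
response = router.completion(
    model="my-azure-gpt",  # a model_name from the router's model_list
    messages=[{"role": "user", "content": "hi"}],
)
hidden = getattr(response, "_hidden_params", {}) or {}
print(hidden.get("model_id"), hidden.get("response_cost"))
```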

litellm/litellm_core_utils/prompt_templates/factory.py

Lines changed: 0 additions & 2 deletions
@@ -1,7 +1,6 @@
 import copy
 import json
 import re
-import traceback
 import uuid
 import xml.etree.ElementTree as ET
 from enum import Enum
@@ -748,7 +747,6 @@ def convert_to_anthropic_image_obj(
             data=base64_data,
         )
     except Exception as e:
-        traceback.print_exc()
         if "Error: Unable to fetch image from URL" in str(e):
             raise e
         raise Exception(

litellm/proxy/caching_routes.py

Lines changed: 0 additions & 1 deletion
@@ -100,7 +100,6 @@ async def cache_ping():
     except Exception as e:
         import traceback
 
-        traceback.print_exc()
         error_message = {
             "message": f"Service Unhealthy ({str(e)})",
             "litellm_cache_params": safe_dumps(litellm_cache_params),

litellm/proxy/management_endpoints/organization_endpoints.py

Lines changed: 0 additions & 3 deletions
@@ -816,9 +816,6 @@ async def add_member_to_organization(
         return user_object, organization_membership
 
     except Exception as e:
-        import traceback
-
-        traceback.print_exc()
         raise ValueError(
             f"Error adding member={member} to organization={organization_id}: {e}"
         )

litellm/router.py

Lines changed: 59 additions & 2 deletions
@@ -116,6 +116,7 @@
     AllowedFailsPolicy,
     AssistantsTypedDict,
     CredentialLiteLLMParams,
+    CustomPricingLiteLLMParams,
     CustomRoutingStrategyBase,
     Deployment,
     DeploymentTypedDict,
@@ -132,6 +133,7 @@
 )
 from litellm.types.services import ServiceTypes
 from litellm.types.utils import GenericBudgetConfigType
+from litellm.types.utils import ModelInfo
 from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.types.utils import StandardLoggingPayload
 from litellm.utils import (
@@ -3324,7 +3326,6 @@ async def async_function_with_fallbacks(self, *args, **kwargs):  # noqa: PLR0915
 
             return response
         except Exception as new_exception:
-            traceback.print_exc()
             parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             verbose_router_logger.error(
                 "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
@@ -4301,7 +4302,20 @@ def _create_deployment(
             model_info=_model_info,
         )
 
+        for field in CustomPricingLiteLLMParams.model_fields.keys():
+            if deployment.litellm_params.get(field) is not None:
+                _model_info[field] = deployment.litellm_params[field]
+
         ## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
+        model_id = deployment.model_info.id
+        if model_id is not None:
+            litellm.register_model(
+                model_cost={
+                    model_id: _model_info,
+                }
+            )
+
+        ## OLD MODEL REGISTRATION ## Kept to prevent breaking changes
         _model_name = deployment.litellm_params.model
         if deployment.litellm_params.custom_llm_provider is not None:
             _model_name = (
@@ -4802,6 +4816,42 @@ def get_model_group(self, id: str) -> Optional[List]:
         model_name = model_info["model_name"]
         return self.get_model_list(model_name=model_name)
 
+    def get_deployment_model_info(
+        self, model_id: str, model_name: str
+    ) -> Optional[ModelInfo]:
+        """
+        For a given model id, return the model info
+
+        1. Check if model_id is in model info
+        2. If not, check if litellm model name is in model info
+        3. If not, return None
+        """
+        from litellm.utils import _update_dictionary
+
+        model_info: Optional[ModelInfo] = None
+        litellm_model_name_model_info: Optional[ModelInfo] = None
+
+        try:
+            model_info = litellm.get_model_info(model=model_id)
+        except Exception:
+            pass
+
+        try:
+            litellm_model_name_model_info = litellm.get_model_info(model=model_name)
+        except Exception:
+            pass
+
+        if model_info is not None and litellm_model_name_model_info is not None:
+            model_info = cast(
+                ModelInfo,
+                _update_dictionary(
+                    cast(dict, litellm_model_name_model_info).copy(),
+                    cast(dict, model_info),
+                ),
+            )
+
+        return model_info
+
     def _set_model_group_info(  # noqa: PLR0915
         self, model_group: str, user_facing_model_group_name: str
     ) -> Optional[ModelGroupInfo]:
@@ -4860,9 +4910,16 @@ def _set_model_group_info(  # noqa: PLR0915
 
         # get model info
         try:
-            model_info = litellm.get_model_info(model=litellm_params.model)
+            model_id = model.get("model_info", {}).get("id", None)
+            if model_id is not None:
+                model_info = self.get_deployment_model_info(
+                    model_id=model_id, model_name=litellm_params.model
+                )
+            else:
+                model_info = None
         except Exception:
             model_info = None
+
         # get llm provider
         litellm_model, llm_provider = "", ""
         try:
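A minimal, untested sketch of the deployment-level registration the `_create_deployment` hunk introduces: custom pricing set in `litellm_params` ends up in `litellm.model_cost` under the deployment's `model_info.id`. All names, keys, and prices below are illustrative:

```python
import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-azure-gpt",
            "litellm_params": {
                "model": "azure/gpt-4o",
                "api_key": "sk-placeholder",                              # placeholder
                "api_base": "https://example-endpoint.openai.azure.com",  # placeholder
                "input_cost_per_token": 0.00000042,
                "output_cost_per_token": 0.00000084,
            },
            "model_info": {"id": "my-deployment-id"},
        }
    ]
)

# _create_deployment registered the deployment's pricing under its id:
print(litellm.model_cost["my-deployment-id"]["input_cost_per_token"])
```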

litellm/types/router.py

Lines changed: 9 additions & 7 deletions
@@ -162,7 +162,15 @@ class CredentialLiteLLMParams(BaseModel):
     watsonx_region_name: Optional[str] = None
 
 
-class GenericLiteLLMParams(CredentialLiteLLMParams):
+class CustomPricingLiteLLMParams(BaseModel):
+    ## CUSTOM PRICING ##
+    input_cost_per_token: Optional[float] = None
+    output_cost_per_token: Optional[float] = None
+    input_cost_per_second: Optional[float] = None
+    output_cost_per_second: Optional[float] = None
+
+
+class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
     """
     LiteLLM Params without 'model' arg (used across completion / assistants api)
     """
@@ -184,12 +192,6 @@ class GenericLiteLLMParams(CredentialLiteLLMParams):
     ## LOGGING PARAMS ##
     litellm_trace_id: Optional[str] = None
 
-    ## CUSTOM PRICING ##
-    input_cost_per_token: Optional[float] = None
-    output_cost_per_token: Optional[float] = None
-    input_cost_per_second: Optional[float] = None
-    output_cost_per_second: Optional[float] = None
-
     max_file_size_mb: Optional[float] = None
 
     # Deployment budgets
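The field values are unchanged for callers; grouping them in `CustomPricingLiteLLMParams` mainly makes them iterable as a unit, which `_create_deployment` relies on via pydantic v2's `model_fields`. A small standalone sketch of that pattern, with assumed values:

```python
from typing import Optional
from pydantic import BaseModel

class CustomPricingLiteLLMParams(BaseModel):
    input_cost_per_token: Optional[float] = None
    output_cost_per_token: Optional[float] = None
    input_cost_per_second: Optional[float] = None
    output_cost_per_second: Optional[float] = None

# Copy only the custom-pricing fields out of a deployment's litellm_params,
# the way router._create_deployment now does.
litellm_params = {"model": "azure/gpt-4o", "input_cost_per_token": 4.2e-7}
model_info: dict = {}
for field in CustomPricingLiteLLMParams.model_fields.keys():
    if litellm_params.get(field) is not None:
        model_info[field] = litellm_params[field]

print(model_info)  # {'input_cost_per_token': 4.2e-7}
```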

litellm/utils.py

Lines changed: 2 additions & 1 deletion
@@ -2245,7 +2245,8 @@ def supports_embedding_image_input(
 ####### HELPER FUNCTIONS ################
 def _update_dictionary(existing_dict: Dict, new_dict: dict) -> dict:
     for k, v in new_dict.items():
-        existing_dict[k] = v
+        if v is not None:
+            existing_dict[k] = v
 
     return existing_dict
 
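This guard matters for the model-info merge in `get_deployment_model_info` above: a deployment entry with unset (None) pricing fields no longer wipes out values inherited from the base model's map entry. A quick illustration with assumed values:

```python
from litellm.utils import _update_dictionary

base = {"input_cost_per_token": 3e-6, "output_cost_per_token": 1.5e-5}
deployment = {"input_cost_per_token": 6e-6, "output_cost_per_token": None}

merged = _update_dictionary(base, deployment)
# The None no longer clobbers the existing output cost:
assert merged == {"input_cost_per_token": 6e-6, "output_cost_per_token": 1.5e-5}
```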
tests/code_coverage_tests/router_code_coverage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_all_functions_called_in_tests(base_dir):
3131
specifically in files containing the word 'router'.
3232
"""
3333
called_functions = set()
34-
test_dirs = ["local_testing", "router_unit_tests"]
34+
test_dirs = ["local_testing", "router_unit_tests", "litellm"]
3535

3636
for test_dir in test_dirs:
3737
dir_path = os.path.join(base_dir, test_dir)
