Skip to content

Commit b6805c8

Browse files
Costs - decimals instead of floats
1 parent 14357e9 commit b6805c8

File tree

5 files changed

+58
-44
lines changed

5 files changed

+58
-44
lines changed

tests/test_all.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
import os
10+
from decimal import Decimal
1011
from pathlib import Path
1112
from typing import Literal
1213
import warnings
@@ -245,11 +246,13 @@ def test_attr_models(self):
245246
with pytest.raises(ValueError):
246247
LLMPricing(input_per_1m_tokens=0.1)
247248
# LLM cost model
248-
_LLMCost(input=0.01, output=0.02, total=0.03)
249+
_LLMCost(input=Decimal("0.01"), output=Decimal("0.02"), total=Decimal("0.03"))
249250
with pytest.raises(ValueError):
250-
_LLMCost(input=-0.01, output=0.02, total=0.03)
251+
_LLMCost(
252+
input=-Decimal("0.01"), output=Decimal("0.02"), total=Decimal("0.03")
253+
)
251254
with pytest.raises(ValueError):
252-
_LLMCost(input=-0.01)
255+
_LLMCost(input=-Decimal("0.001"))
253256

254257
def test_init_instance_bases(self):
255258
"""
@@ -2481,6 +2484,6 @@ def test_total_cost_and_reset(self):
24812484
assert usage_dict.usage.input == 0
24822485
assert usage_dict.usage.output == 0
24832486
for cost_dict in self.llm_group.get_cost() + self.llm_extractor_text.get_cost():
2484-
assert cost_dict.cost.input == 0
2485-
assert cost_dict.cost.output == 0
2486-
assert cost_dict.cost.total == 0
2487+
assert cost_dict.cost.input == Decimal("0.00000")
2488+
assert cost_dict.cost.output == Decimal("0.00000")
2489+
assert cost_dict.cost.total == Decimal("0.00000")

tests/utils.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import os
1111
import re
1212
from copy import deepcopy
13+
from decimal import Decimal, ROUND_HALF_UP
1314
from pathlib import Path
1415
import time
1516
from typing import Callable, Literal
@@ -227,9 +228,9 @@ def check_cost(llm: DocumentLLMGroup | DocumentLLM) -> None:
227228
assert hasattr(cost_item, "is_fallback")
228229
assert hasattr(cost_item, "cost") and cost_item.cost
229230
assert isinstance(cost_item.cost, _LLMCost)
230-
if cost_item.cost.input > 0:
231+
if cost_item.cost.input > Decimal("0.00000"):
231232
cost_input_updated = True
232-
if cost_item.cost.output > 0:
233+
if cost_item.cost.output > Decimal("0.00000"):
233234
cost_output_updated = True
234235
assert cost_input_updated and cost_output_updated
235236

@@ -240,22 +241,25 @@ def output_test_costs(self) -> None:
240241
:return: None
241242
"""
242243

243-
def get_cost_as_float(llm):
244-
return round(sum([i.cost.total for i in llm.get_cost()]), 6)
244+
def get_cost_as_decimal(llm):
245+
# Sum all the total costs from each cost item starting from a Decimal zero.
246+
total = sum((i.cost.total for i in llm.get_cost()), Decimal("0.00000"))
247+
return total.quantize(Decimal("0.00001"), rounding=ROUND_HALF_UP)
245248

246-
total_cost_llm_group = get_cost_as_float(self.llm_group)
247-
total_cost_llm = get_cost_as_float(self.llm_extractor_text)
248-
total_cost_llm_with_fallback = get_cost_as_float(self.llm_with_fallback)
249+
total_cost_llm_group = get_cost_as_decimal(self.llm_group)
250+
total_cost_llm = get_cost_as_decimal(self.llm_extractor_text)
251+
total_cost_llm_with_fallback = get_cost_as_decimal(self.llm_with_fallback)
249252
print("Cost of running tests (LLM 0 - group): ", total_cost_llm_group)
250253
print("Cost of running tests (LLM 1 - individual): ", total_cost_llm)
251254
print(
252255
"Cost of running tests (LLM with fallback): ", total_cost_llm_with_fallback
253256
)
257+
total_cost = (
258+
total_cost_llm_group + total_cost_llm + total_cost_llm_with_fallback
259+
)
254260
print(
255261
"Total cost running tests: ",
256-
round(
257-
total_cost_llm_group + total_cost_llm + total_cost_llm_with_fallback, 6
258-
),
262+
total_cost.quantize(Decimal("0.00001"), rounding=ROUND_HALF_UP),
259263
)
260264

261265
@staticmethod

xdoc/internal/data_models.py

Lines changed: 9 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,13 @@
55
"""
66

77
import time
8+
from decimal import Decimal
89
from typing import Any, Optional
910

1011
from pydantic import BaseModel, Field, ConfigDict
11-
from pydantic import StrictFloat, StrictStr, StrictInt, StrictBool
12+
from pydantic import StrictStr, StrictInt, StrictBool
1213

13-
from xdoc.internal.typings.aliases import NonEmptyStr, LLMRoleAny
14+
from xdoc.internal.typings.aliases import NonEmptyStr, LLMRoleAny, DefaultDecimalField
1415

1516

1617
class _LLMCall(BaseModel):
@@ -132,25 +133,16 @@ class _LLMCost(BaseModel):
132133
processing inputs, outputs, and the total processing cost.
133134
134135
:ivar input: Cost associated with processing the input.
135-
:type input: float
136+
:type input: Decimal
136137
:ivar output: Cost associated with generating the output.
137-
:type output: float
138+
:type output: Decimal
138139
:ivar total: Total cost combining both input and output processing.
139-
:type total: float
140+
:type total: Decimal
140141
"""
141142

142-
input: StrictFloat = Field(
143-
default=0.0,
144-
ge=0,
145-
)
146-
output: StrictFloat = Field(
147-
default=0.0,
148-
ge=0,
149-
)
150-
total: StrictFloat = Field(
151-
default=0.0,
152-
ge=0,
153-
)
143+
input: Decimal = DefaultDecimalField
144+
output: Decimal = DefaultDecimalField
145+
total: Decimal = DefaultDecimalField
154146

155147
model_config = ConfigDict(extra="forbid", validate_assignment=True)
156148

xdoc/internal/typings/aliases.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,10 @@
44
Internal module containing type aliases used throughout the framework.
55
"""
66

7+
from decimal import Decimal
78
from typing import Annotated, Literal, Callable, Coroutine, Any
89

9-
from pydantic import StrictStr, StringConstraints
10+
from pydantic import StrictStr, StringConstraints, Field
1011

1112

1213
NonEmptyStr = Annotated[
@@ -44,3 +45,7 @@
4445
AsyncCalsAndKwargs = list[
4546
tuple[Callable[..., Coroutine[Any, Any, Any]], dict[str, Any]]
4647
]
48+
49+
DefaultDecimalField = Field(
50+
default_factory=lambda: Decimal("0.00000"), ge=Decimal("0.00000")
51+
)

xdoc/public/llms.py

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88

99
import asyncio
1010
import warnings
11+
from decimal import Decimal, ROUND_HALF_UP
1112
from typing import TYPE_CHECKING, Optional, Any, Self
1213

1314
from aiolimiter import AsyncLimiter
@@ -714,20 +715,29 @@ def _increment_cost(self, usage: _LLMUsage) -> None:
714715
"""
715716

716717
if self.pricing_details:
717-
cost_input = (
718-
usage.input / 1_000_000
719-
) * self.pricing_details.input_per_1m_tokens
720-
cost_output = (
721-
usage.output / 1_000_000
722-
) * self.pricing_details.output_per_1m_tokens
718+
mil_dec = Decimal("1000000")
719+
cost_input = (Decimal(str(usage.input)) / mil_dec) * Decimal(
720+
str(self.pricing_details.input_per_1m_tokens)
721+
)
722+
cost_output = (Decimal(str(usage.output)) / mil_dec) * Decimal(
723+
str(self.pricing_details.output_per_1m_tokens)
724+
)
723725
cost_total = cost_input + cost_output
726+
724727
self._cost.input += cost_input
725728
self._cost.output += cost_output
726729
self._cost.total += cost_total
727-
round_cost = lambda x: round(x, 6)
728-
self._cost.input = round_cost(self._cost.input)
729-
self._cost.output = round_cost(self._cost.output)
730-
self._cost.total = round_cost(self._cost.total)
730+
731+
round_dec = Decimal("0.00001")
732+
self._cost.input = self._cost.input.quantize(
733+
round_dec, rounding=ROUND_HALF_UP
734+
)
735+
self._cost.output = self._cost.output.quantize(
736+
round_dec, rounding=ROUND_HALF_UP
737+
)
738+
self._cost.total = self._cost.total.quantize(
739+
round_dec, rounding=ROUND_HALF_UP
740+
)
731741

732742
def _update_usage_and_cost(self, result: tuple[Any, _LLMUsage]) -> None:
733743
"""

0 commit comments

Comments
 (0)