1010import os
1111import re
1212from copy import deepcopy
13+ from decimal import Decimal , ROUND_HALF_UP
1314from pathlib import Path
1415import time
1516from typing import Callable , Literal
@@ -227,9 +228,9 @@ def check_cost(llm: DocumentLLMGroup | DocumentLLM) -> None:
227228 assert hasattr (cost_item , "is_fallback" )
228229 assert hasattr (cost_item , "cost" ) and cost_item .cost
229230 assert isinstance (cost_item .cost , _LLMCost )
230- if cost_item .cost .input > 0 :
231+ if cost_item .cost .input > Decimal ( "0.00000" ) :
231232 cost_input_updated = True
232- if cost_item .cost .output > 0 :
233+ if cost_item .cost .output > Decimal ( "0.00000" ) :
233234 cost_output_updated = True
234235 assert cost_input_updated and cost_output_updated
235236
@@ -240,22 +241,25 @@ def output_test_costs(self) -> None:
240241 :return: None
241242 """
242243
243- def get_cost_as_float (llm ):
244- return round (sum ([i .cost .total for i in llm .get_cost ()]), 6 )
244+ def get_cost_as_decimal (llm ):
245+ # Sum all the total costs from each cost item starting from a Decimal zero.
246+ total = sum ((i .cost .total for i in llm .get_cost ()), Decimal ("0.00000" ))
247+ return total .quantize (Decimal ("0.00001" ), rounding = ROUND_HALF_UP )
245248
246- total_cost_llm_group = get_cost_as_float (self .llm_group )
247- total_cost_llm = get_cost_as_float (self .llm_extractor_text )
248- total_cost_llm_with_fallback = get_cost_as_float (self .llm_with_fallback )
249+ total_cost_llm_group = get_cost_as_decimal (self .llm_group )
250+ total_cost_llm = get_cost_as_decimal (self .llm_extractor_text )
251+ total_cost_llm_with_fallback = get_cost_as_decimal (self .llm_with_fallback )
249252 print ("Cost of running tests (LLM 0 - group): " , total_cost_llm_group )
250253 print ("Cost of running tests (LLM 1 - individual): " , total_cost_llm )
251254 print (
252255 "Cost of running tests (LLM with fallback): " , total_cost_llm_with_fallback
253256 )
257+ total_cost = (
258+ total_cost_llm_group + total_cost_llm + total_cost_llm_with_fallback
259+ )
254260 print (
255261 "Total cost running tests: " ,
256- round (
257- total_cost_llm_group + total_cost_llm + total_cost_llm_with_fallback , 6
258- ),
262+ total_cost .quantize (Decimal ("0.00001" ), rounding = ROUND_HALF_UP ),
259263 )
260264
261265 @staticmethod
0 commit comments