diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 41b381e..434ae6b 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -560,7 +560,7 @@ def create_scores(self, scores: List[ScoreDict]): variables[f"{k}_{id}"] = v def process_response(response): - return [Score.from_dict(x) for x in response["data"].values()] + return [x for x in response["data"].values()] return self.gql_helper(query, "create scores", variables, process_response) diff --git a/literalai/client.py b/literalai/client.py index d47a4bf..8ecf11d 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -30,7 +30,7 @@ class BaseLiteralClient: def __init__( self, - batch_size: int = 1, + batch_size: int = 5, is_async: bool = False, api_key: Optional[str] = None, url: Optional[str] = None, @@ -274,7 +274,7 @@ class LiteralClient(BaseLiteralClient): def __init__( self, - batch_size: int = 1, + batch_size: int = 5, api_key: Optional[str] = None, url: Optional[str] = None, environment: Optional[Environment] = None, @@ -298,7 +298,7 @@ class AsyncLiteralClient(BaseLiteralClient): def __init__( self, - batch_size: int = 1, + batch_size: int = 5, api_key: Optional[str] = None, url: Optional[str] = None, environment: Optional[Environment] = None, diff --git a/literalai/event_processor.py b/literalai/event_processor.py index ed512da..704872c 100644 --- a/literalai/event_processor.py +++ b/literalai/event_processor.py @@ -38,15 +38,21 @@ def __init__(self, api: "LiteralAPI", batch_size: int = 1, disabled: bool = Fals target=self._process_events, daemon=True ) self.disabled = disabled + self.processing_counter = 0 + self.counter_lock = threading.Lock() if not self.disabled: self.processing_thread.start() self.stop_event = threading.Event() def add_event(self, event: "StepDict"): + with self.counter_lock: + self.processing_counter += 1 self.event_queue.put(event) async def a_add_events(self, event: "StepDict"): + with self.counter_lock: + self.processing_counter += 1 await to_thread(self.event_queue.put, event) def _process_events(self): @@ -62,7 +68,7 @@ def _process_events(self): # No more events at the moment, proceed with processing what's in the batch pass - # Process the batch if any events are present - in a separate thread + # Process the batch if any events are present if batch: self._process_batch(batch) @@ -83,6 +89,8 @@ def _process_batch(self, batch: List): while not self._try_process_batch(batch) and retries < 1: retries += 1 time.sleep(DEFAULT_SLEEP_TIME) + with self.counter_lock: + self.processing_counter -= len(batch) def flush_and_stop(self): self.stop_event.set() @@ -90,12 +98,16 @@ def flush_and_stop(self): self.processing_thread.join() async def aflush(self): - while not self.event_queue.empty(): + while not self.event_queue.empty() or self._is_processing(): await asyncio.sleep(DEFAULT_SLEEP_TIME) def flush(self): - while not self.event_queue.empty(): + while not self.event_queue.empty() or self._is_processing(): time.sleep(DEFAULT_SLEEP_TIME) + def _is_processing(self): + with self.counter_lock: + return self.processing_counter > 0 + def __del__(self): self.flush_and_stop() diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 5780f71..68890e2 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -32,7 +32,7 @@ def broken_client(self): api_key = os.getenv("LITERAL_API_KEY", None) assert url is not None and api_key is not None, "Missing environment variables" - client = LiteralClient(batch_size=1, url=url, api_key=api_key) + client = LiteralClient(batch_size=5, url=url, api_key=api_key) yield client client.event_processor.flush_and_stop() @@ -42,7 +42,7 @@ def client(self): api_key = os.getenv("LITERAL_API_KEY", None) assert url is not None and api_key is not None, "Missing environment variables" - client = LiteralClient(batch_size=1, url=url, api_key=api_key) + client = LiteralClient(batch_size=5, url=url, api_key=api_key) yield client client.event_processor.flush_and_stop() @@ -53,7 +53,7 @@ def staging_client(self): assert url is not None and api_key is not None, "Missing environment variables" client = LiteralClient( - batch_size=1, url=url, api_key=api_key, environment="staging" + batch_size=5, url=url, api_key=api_key, environment="staging" ) yield client client.event_processor.flush_and_stop() @@ -64,7 +64,7 @@ def async_client(self): api_key = os.getenv("LITERAL_API_KEY", None) assert url is not None and api_key is not None, "Missing environment variables" - async_client = AsyncLiteralClient(batch_size=1, url=url, api_key=api_key) + async_client = AsyncLiteralClient(batch_size=5, url=url, api_key=api_key) yield async_client async_client.event_processor.flush_and_stop() diff --git a/tests/e2e/test_mistralai.py b/tests/e2e/test_mistralai.py index b858bcb..a61130e 100644 --- a/tests/e2e/test_mistralai.py +++ b/tests/e2e/test_mistralai.py @@ -32,7 +32,7 @@ def client(self): api_key = os.getenv("LITERAL_API_KEY", None) assert url is not None and api_key is not None, "Missing environment variables" - client = LiteralClient(batch_size=1, url=url, api_key=api_key) + client = LiteralClient(batch_size=5, url=url, api_key=api_key) client.instrument_mistralai() return client diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index c0aec52..0c04c45 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -33,7 +33,7 @@ def client(self): api_key = os.getenv("LITERAL_API_KEY", None) assert url is not None and api_key is not None, "Missing environment variables" - client = LiteralClient(batch_size=1, url=url, api_key=api_key) + client = LiteralClient(batch_size=5, url=url, api_key=api_key) client.instrument_openai() return client