diff --git a/literalai/client.py b/literalai/client.py index 3d06ea7..4f33a41 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -2,7 +2,7 @@ import json import os from contextlib import redirect_stdout -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from traceloop.sdk import Traceloop from typing_extensions import deprecated @@ -29,6 +29,7 @@ MessageStepType, Step, StepContextManager, + StepDict, TrueStepType, step_decorator, ) @@ -94,15 +95,55 @@ def __init__( def to_sync(self) -> "LiteralClient": if isinstance(self.api, AsyncLiteralAPI): - return LiteralClient( + sync_client = LiteralClient( batch_size=self.event_processor.batch_size, api_key=self.api.api_key, url=self.api.url, disabled=self.disabled, ) + if self.event_processor.preprocess_steps_function: + sync_client.event_processor.set_preprocess_steps_function( + self.event_processor.preprocess_steps_function + ) + + return sync_client else: return self # type: ignore + def set_preprocess_steps_function( + self, + preprocess_steps_function: Optional[ + Callable[[List["StepDict"]], List["StepDict"]] + ], + ) -> None: + """ + Set a function that will preprocess steps before sending them to the API. + This can be used for tasks like PII removal or other data transformations. + + The preprocess function should: + - Accept a list of StepDict objects as input + - Return a list of modified StepDict objects + - Be thread-safe as it will be called from a background thread + - Handle exceptions internally to avoid disrupting the event processing + + Example: + ```python + def remove_pii(steps): + # Process steps to remove PII data + for step in steps: + if step.get("content"): + step["content"] = my_pii_removal_function(step["content"]) + return steps + + client.set_preprocess_steps_function(remove_pii) + ``` + + Args: + preprocess_steps_function (Callable[[List["StepDict"]], List["StepDict"]]): + Function that takes a list of steps and returns a processed list + """ + self.event_processor.set_preprocess_steps_function(preprocess_steps_function) + @deprecated("Use Literal.initialize instead") def instrument_openai(self): """ diff --git a/literalai/event_processor.py b/literalai/event_processor.py index aae1f61..b8a92ec 100644 --- a/literalai/event_processor.py +++ b/literalai/event_processor.py @@ -4,7 +4,7 @@ import threading import time import traceback -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Callable, List, Optional logger = logging.getLogger(__name__) @@ -31,7 +31,15 @@ class EventProcessor: batch: List["StepDict"] batch_timeout: float = 5.0 - def __init__(self, api: "LiteralAPI", batch_size: int = 1, disabled: bool = False): + def __init__( + self, + api: "LiteralAPI", + batch_size: int = 1, + disabled: bool = False, + preprocess_steps_function: Optional[ + Callable[[List["StepDict"]], List["StepDict"]] + ] = None, + ): self.stop_event = threading.Event() self.batch_size = batch_size self.api = api @@ -40,6 +48,7 @@ def __init__(self, api: "LiteralAPI", batch_size: int = 1, disabled: bool = Fals self.processing_counter = 0 self.counter_lock = threading.Lock() self.last_batch_time = time.time() + self.preprocess_steps_function = preprocess_steps_function self.processing_thread = threading.Thread( target=self._process_events, daemon=True ) @@ -56,6 +65,22 @@ async def a_add_events(self, event: "StepDict"): self.processing_counter += 1 await to_thread(self.event_queue.put, event) + def set_preprocess_steps_function( + self, + preprocess_steps_function: Optional[ + Callable[[List["StepDict"]], List["StepDict"]] + ], + ): + """ + Set a function that will preprocess steps before sending them to the API. + The function should take a list of StepDict objects and return a list of processed StepDict objects. + This can be used for tasks like PII removal. + + Args: + preprocess_steps_function (Callable[[List["StepDict"]], List["StepDict"]]): The preprocessing function + """ + self.preprocess_steps_function = preprocess_steps_function + def _process_events(self): while True: batch = [] @@ -83,6 +108,24 @@ def _process_events(self): def _try_process_batch(self, batch: List): try: + # Apply preprocessing function if it exists + if self.preprocess_steps_function is not None: + try: + processed_batch = self.preprocess_steps_function(batch) + # Only use the processed batch if it's valid + if processed_batch is not None and isinstance( + processed_batch, list + ): + batch = processed_batch + else: + logger.warning( + "Preprocess function returned invalid result, using original batch" + ) + except Exception as e: + logger.error(f"Error in preprocess function: {str(e)}") + logger.error(traceback.format_exc()) + # Continue with the original batch + return self.api.send_steps(batch) except Exception: logger.error(f"Failed to send steps: {traceback.format_exc()}") diff --git a/literalai/version.py b/literalai/version.py index 394520a..ae73625 100644 --- a/literalai/version.py +++ b/literalai/version.py @@ -1 +1 @@ -__version__ = "0.1.202" +__version__ = "0.1.3" diff --git a/setup.py b/setup.py index 50d3fea..982118d 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="literalai", - version="0.1.202", # update version in literalai/version.py + version="0.1.3", # update version in literalai/version.py description="An SDK for observability in Python applications", long_description=open("README.md").read(), long_description_content_type="text/markdown", diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 2546e48..704d513 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -801,3 +801,97 @@ async def test_environment(self, staging_client: LiteralClient): persisted_run = staging_client.api.get_step(run_id) assert persisted_run is not None assert persisted_run.environment == "staging" + + @pytest.mark.timeout(5) + async def test_pii_removal( + self, client: LiteralClient, async_client: AsyncLiteralClient + ): + """Test that PII is properly removed by the preprocess function.""" + import re + + # Define a PII removal function + def remove_pii(steps): + # Patterns for common PII + email_pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" + phone_pattern = r"\b(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b" + ssn_pattern = r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b" + + for step in steps: + # Process content field if it exists + if "output" in step and step["output"]["content"]: + # Replace emails with [EMAIL REDACTED] + step["output"]["content"] = re.sub( + email_pattern, "[EMAIL REDACTED]", step["output"]["content"] + ) + + # Replace phone numbers with [PHONE REDACTED] + step["output"]["content"] = re.sub( + phone_pattern, "[PHONE REDACTED]", step["output"]["content"] + ) + + # Replace SSNs with [SSN REDACTED] + step["output"]["content"] = re.sub( + ssn_pattern, "[SSN REDACTED]", step["output"]["content"] + ) + + return steps + + # Set the PII removal function on the client + client.set_preprocess_steps_function(remove_pii) + + @client.thread + def thread_with_pii(): + thread = client.get_current_thread() + + # User message with PII + user_step = client.message( + content="My email is test@example.com and my phone is (123) 456-7890. My SSN is 123-45-6789.", + type="user_message", + metadata={"contact_info": "Call me at 987-654-3210"}, + ) + user_step_id = user_step.id + + # Assistant message with PII reference + assistant_step = client.message( + content="I'll contact you at test@example.com", type="assistant_message" + ) + assistant_step_id = assistant_step.id + + return thread.id, user_step_id, assistant_step_id + + # Run the thread + thread_id, user_step_id, assistant_step_id = thread_with_pii() + + # Wait for processing to occur + client.flush() + + # Fetch the steps and verify PII was removed + user_step = client.api.get_step(id=user_step_id) + assistant_step = client.api.get_step(id=assistant_step_id) + + assert user_step + assert assistant_step + + user_step_output = user_step.output["content"] # type: ignore + + # Check user message + assert "test@example.com" not in user_step_output + assert "(123) 456-7890" not in user_step_output + assert "123-45-6789" not in user_step_output + assert "[EMAIL REDACTED]" in user_step_output + assert "[PHONE REDACTED]" in user_step_output + assert "[SSN REDACTED]" in user_step_output + + assistant_step_output = assistant_step.output["content"] # type: ignore + + # Check assistant message + assert "test@example.com" not in assistant_step_output + assert "[EMAIL REDACTED]" in assistant_step_output + + # Clean up + client.api.delete_step(id=user_step_id) + client.api.delete_step(id=assistant_step_id) + client.api.delete_thread(id=thread_id) + + # Reset the preprocess function to avoid affecting other tests + client.set_preprocess_steps_function(None)