diff --git a/guardrails_api/__init__.py b/guardrails_api/__init__.py
index 5c71f76..30499ab 100644
--- a/guardrails_api/__init__.py
+++ b/guardrails_api/__init__.py
@@ -1 +1 @@
-__version__ = "0.1.0-alpha1"
+__version__ = "0.1.0-alpha2"
diff --git a/guardrails_api/api/guards.py b/guardrails_api/api/guards.py
index a9f70d6..2030e32 100644
--- a/guardrails_api/api/guards.py
+++ b/guardrails_api/api/guards.py
@@ -42,6 +42,8 @@
 router = APIRouter()
 
+def guard_history_is_enabled():
+    return os.environ.get("GUARD_HISTORY_ENABLED", "true").lower() == "true"
 
 
 @router.get("/guards")
 @handle_error
@@ -125,7 +127,12 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
     )
 
     if not stream:
-        validation_outcome: ValidationOutcome = await guard(num_reasks=0, **payload)
+        execution = guard(num_reasks=0, **payload)
+        if inspect.iscoroutine(execution):
+            validation_outcome: ValidationOutcome = await execution
+        else:
+            validation_outcome: ValidationOutcome = execution
+
         llm_response = guard.history.last.iterations.last.outputs.llm_response_info
         result = outcome_to_chat_completion(
             validation_outcome=validation_outcome,
@@ -136,13 +143,17 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
     else:
 
         async def openai_streamer():
-            guard_stream = await guard(num_reasks=0, **payload)
-            async for result in guard_stream:
-                chunk = json.dumps(
-                    outcome_to_stream_response(validation_outcome=result)
-                )
-                yield f"data: {chunk}\n\n"
-            yield "\n"
+            try:
+                guard_stream = await guard(num_reasks=0, **payload)
+                async for result in guard_stream:
+                    chunk = json.dumps(
+                        outcome_to_stream_response(validation_outcome=result)
+                    )
+                    yield f"data: {chunk}\n\n"
+                yield "\n"
+            except Exception as e:
+                yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n"
+                yield "\n"
 
         return StreamingResponse(openai_streamer(), media_type="text/event-stream")
@@ -196,58 +207,75 @@ async def validate(guard_name: str, request: Request):
             raise HTTPException(
                 status_code=400, detail="Streaming is not supported for parse calls!"
             )
-        result: ValidationOutcome = guard.parse(
+        execution = guard.parse(
             llm_output=llm_output,
             num_reasks=num_reasks,
             prompt_params=prompt_params,
             llm_api=llm_api,
             **payload,
         )
+        if inspect.iscoroutine(execution):
+            result: ValidationOutcome = await execution
+        else:
+            result: ValidationOutcome = execution
     else:
         if stream:
 
             async def guard_streamer():
-                guard_stream = guard(
-                    llm_api=llm_api,
-                    prompt_params=prompt_params,
-                    num_reasks=num_reasks,
-                    stream=stream,
-                    *args,
-                    **payload,
-                )
-                for result in guard_stream:
-                    validation_output = ValidationOutcome.from_guard_history(
-                        guard.history.last
-                    )
-                    yield validation_output, result
+                call = guard(
+                    llm_api=llm_api,
+                    prompt_params=prompt_params,
+                    num_reasks=num_reasks,
+                    stream=stream,
+                    *args,
+                    **payload,
+                )
+                is_async = inspect.iscoroutine(call)
+                if is_async:
+                    guard_stream = await call
+                    async for result in guard_stream:
+                        validation_output = ValidationOutcome.from_guard_history(
+                            guard.history.last
+                        )
+                        yield validation_output, result
+                else:
+                    guard_stream = call
+                    for result in guard_stream:
+                        validation_output = ValidationOutcome.from_guard_history(
+                            guard.history.last
+                        )
+                        yield validation_output, result
 
             async def validate_streamer(guard_iter):
-                async for validation_output, result in guard_iter:
-                    fragment_dict = result.to_dict()
-                    fragment_dict["error_spans"] = [
+                try:
+                    async for validation_output, result in guard_iter:
+                        fragment_dict = result.to_dict()
+                        fragment_dict["error_spans"] = [
+                            json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
+                            for x in guard.error_spans_in_output()
+                        ]
+                        yield json.dumps(fragment_dict) + "\n"
+
+                    call = guard.history.last
+                    final_validation_output = ValidationOutcome(
+                        callId=call.id,
+                        validation_passed=result.validation_passed,
+                        validated_output=result.validated_output,
+                        history=guard.history,
+                        raw_llm_output=result.raw_llm_output,
+                    )
+                    final_output_dict = final_validation_output.to_dict()
+                    final_output_dict["error_spans"] = [
                         json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
                         for x in guard.error_spans_in_output()
                     ]
-                    yield json.dumps(fragment_dict) + "\n"
-
-                call = guard.history.last
-                final_validation_output = ValidationOutcome(
-                    callId=call.id,
-                    validation_passed=result.validation_passed,
-                    validated_output=result.validated_output,
-                    history=guard.history,
-                    raw_llm_output=result.raw_llm_output,
-                )
-                final_output_dict = final_validation_output.to_dict()
-                final_output_dict["error_spans"] = [
-                    json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
-                    for x in guard.error_spans_in_output()
-                ]
-                yield json.dumps(final_output_dict) + "\n"
-
-                serialized_history = [call.to_dict() for call in guard.history]
-                cache_key = f"{guard.name}-{final_validation_output.call_id}"
-                await cache_client.set(cache_key, serialized_history, 300)
+                    yield json.dumps(final_output_dict) + "\n"
+                except Exception as e:
+                    yield json.dumps({"error": {"message": str(e)}}) + "\n"
+
+                if guard_history_is_enabled():
+                    serialized_history = [call.to_dict() for call in guard.history]
+                    cache_key = f"{guard.name}-{final_validation_output.call_id}"
+                    await cache_client.set(cache_key, serialized_history, 300)
 
             return StreamingResponse(
                 validate_streamer(guard_streamer()), media_type="application/json"
@@ -260,15 +288,14 @@ async def validate_streamer(guard_iter):
             *args,
             **payload,
         )
-
         if inspect.iscoroutine(execution):
             result: ValidationOutcome = await execution
         else:
             result: ValidationOutcome = execution
-
-        serialized_history = [call.to_dict() for call in guard.history]
- cache_key = f"{guard.name}-{result.call_id}" - await cache_client.set(cache_key, serialized_history, 300) + if guard_history_is_enabled(): + serialized_history = [call.to_dict() for call in guard.history] + cache_key = f"{guard.name}-{result.call_id}" + await cache_client.set(cache_key, serialized_history, 300) return result.to_dict() diff --git a/tests/api/test_guards.py b/tests/api/test_guards.py index 416baf0..ef58c7c 100644 --- a/tests/api/test_guards.py +++ b/tests/api/test_guards.py @@ -14,9 +14,6 @@ from tests.mocks.mock_guard_client import MockGuardStruct from guardrails_api.api.guards import router as guards_router - -import asyncio - # TODO: Should we mock this somehow? # Right now it's just empty, but it technically does a file read register_config() @@ -347,9 +344,8 @@ def test_openai_v1_chat_completions__call(mocker): ) mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") - future = asyncio.Future() - future.set_result(mock_outcome) - mock___call__.return_value = future + + mock___call__.return_value = mock_outcome mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict") mock_from_dict.return_value = mock_guard