
add support for streaming validation exceptions and exception handling, disable history support because of poor multi-node support #80

Merged

merged 9 commits on Oct 25, 2024
2 changes: 1 addition & 1 deletion guardrails_api/__init__.py
@@ -1 +1 @@
-__version__ = "0.1.0-alpha1"
+__version__ = "0.1.0-alpha2"
127 changes: 77 additions & 50 deletions guardrails_api/api/guards.py
@@ -42,6 +42,8 @@

 router = APIRouter()

+def guard_history_is_enabled():
+    return os.environ.get("GUARD_HISTORY_ENABLED", "true").lower() == "true"

 @router.get("/guards")
 @handle_error
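Guard history is now opt-out via an environment variable and stays on by default. A minimal sketch of exercising the toggle, assuming only the GUARD_HISTORY_ENABLED variable from the diff; setting it in-process here is purely illustrative (a deployment would export it before starting the service):

import os

# Normally set by the deployment, e.g. `export GUARD_HISTORY_ENABLED=false`.
os.environ["GUARD_HISTORY_ENABLED"] = "false"

def guard_history_is_enabled():
    # Same logic as the helper above: defaults to enabled, case-insensitive.
    return os.environ.get("GUARD_HISTORY_ENABLED", "true").lower() == "true"

assert guard_history_is_enabled() is False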
@@ -125,7 +127,12 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
     )

     if not stream:
-        validation_outcome: ValidationOutcome = await guard(num_reasks=0, **payload)
+        execution = guard(num_reasks=0, **payload)
+        if inspect.iscoroutine(execution):
+            validation_outcome: ValidationOutcome = await execution
+        else:
+            validation_outcome: ValidationOutcome = execution
+
         llm_response = guard.history.last.iterations.last.outputs.llm_response_info
         result = outcome_to_chat_completion(
             validation_outcome=validation_outcome,
@@ -136,13 +143,17 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
     else:

         async def openai_streamer():
-            guard_stream = await guard(num_reasks=0, **payload)
-            async for result in guard_stream:
-                chunk = json.dumps(
-                    outcome_to_stream_response(validation_outcome=result)
-                )
-                yield f"data: {chunk}\n\n"
-            yield "\n"
+            try:
+                guard_stream = await guard(num_reasks=0, **payload)
+                async for result in guard_stream:
+                    chunk = json.dumps(
+                        outcome_to_stream_response(validation_outcome=result)
+                    )
+                    yield f"data: {chunk}\n\n"
+                yield "\n"
+            except Exception as e:
+                yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n"
+                yield "\n"

         return StreamingResponse(openai_streamer(), media_type="text/event-stream")
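With the streamer wrapped in try/except, a mid-stream failure now arrives as one last well-formed SSE frame instead of a dropped connection. A sketch of how a client might parse those frames, assuming only the `data:` framing and the {'error': {'message': ...}} shape emitted above; the helper name and canned input are ours:

import json

def iter_sse_payloads(lines):
    # Parse `data: <json>` frames; surface server-reported errors.
    for line in lines:
        if not line.startswith("data: "):
            continue  # skip blank keep-alive lines
        payload = json.loads(line[len("data: "):])
        if "error" in payload:
            raise RuntimeError(payload["error"]["message"])
        yield payload

# Example with a canned frame:
for chunk in iter_sse_payloads(['data: {"choices": []}']):
    print(chunk)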

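The inspect.iscoroutine check recurs in every branch of this PR so that a single endpoint can serve both synchronous and asynchronous guards. The same pattern factored into a helper, as a sketch; the `resolve` name is ours, not part of the codebase:

import inspect

async def resolve(execution):
    # Async guards return a coroutine that must be awaited;
    # sync guards return the ValidationOutcome directly.
    if inspect.iscoroutine(execution):
        return await execution
    return execution

# Inside an async route:
#   validation_outcome = await resolve(guard(num_reasks=0, **payload))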
@@ -196,58 +207,75 @@ async def validate(guard_name: str, request: Request):
             raise HTTPException(
                 status_code=400, detail="Streaming is not supported for parse calls!"
             )
-        result: ValidationOutcome = guard.parse(
+        execution = guard.parse(
             llm_output=llm_output,
             num_reasks=num_reasks,
             prompt_params=prompt_params,
             llm_api=llm_api,
             **payload,
         )
+        if inspect.iscoroutine(execution):
+            result: ValidationOutcome = await execution
+        else:
+            result: ValidationOutcome = execution
     else:
         if stream:

             async def guard_streamer():
-                guard_stream = guard(
-                    llm_api=llm_api,
-                    prompt_params=prompt_params,
-                    num_reasks=num_reasks,
-                    stream=stream,
-                    *args,
-                    **payload,
-                )
-                for result in guard_stream:
-                    validation_output = ValidationOutcome.from_guard_history(
-                        guard.history.last
-                    )
-                    yield validation_output, result
+                call = guard(
+                    llm_api=llm_api,
+                    prompt_params=prompt_params,
+                    num_reasks=num_reasks,
+                    stream=stream,
+                    *args,
+                    **payload,
+                )
+                is_async = inspect.iscoroutine(call)
+                if is_async:
+                    guard_stream = await call
+                    async for result in guard_stream:
+                        validation_output = ValidationOutcome.from_guard_history(
+                            guard.history.last
+                        )
+                        yield validation_output, result
+                else:
+                    guard_stream = call
+                    for result in guard_stream:
+                        validation_output = ValidationOutcome.from_guard_history(
+                            guard.history.last
+                        )
+                        yield validation_output, result

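guard_streamer above branches once on iscoroutine and then duplicates the iteration body for the `async for` and plain `for` cases. An equivalent sketch that normalizes both shapes into a single async iterator; the `as_async_iter` helper is ours:

import inspect

async def as_async_iter(call):
    # Accepts either a coroutine resolving to a stream, or a sync iterable.
    stream = await call if inspect.iscoroutine(call) else call
    if hasattr(stream, "__aiter__"):
        async for item in stream:
            yield item
    else:
        for item in stream:
            yield item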
             async def validate_streamer(guard_iter):
-                async for validation_output, result in guard_iter:
-                    fragment_dict = result.to_dict()
-                    fragment_dict["error_spans"] = [
-                        json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
-                        for x in guard.error_spans_in_output()
-                    ]
-                    yield json.dumps(fragment_dict) + "\n"
-
-                call = guard.history.last
-                final_validation_output = ValidationOutcome(
-                    callId=call.id,
-                    validation_passed=result.validation_passed,
-                    validated_output=result.validated_output,
-                    history=guard.history,
-                    raw_llm_output=result.raw_llm_output,
-                )
-                final_output_dict = final_validation_output.to_dict()
-                final_output_dict["error_spans"] = [
-                    json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
-                    for x in guard.error_spans_in_output()
-                ]
-                yield json.dumps(final_output_dict) + "\n"
-
-                serialized_history = [call.to_dict() for call in guard.history]
-                cache_key = f"{guard.name}-{final_validation_output.call_id}"
-                await cache_client.set(cache_key, serialized_history, 300)
+                try:
+                    async for validation_output, result in guard_iter:
+                        fragment_dict = result.to_dict()
+                        fragment_dict["error_spans"] = [
+                            json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
+                            for x in guard.error_spans_in_output()
+                        ]
+                        yield json.dumps(fragment_dict) + "\n"
+
+                    call = guard.history.last
+                    final_validation_output = ValidationOutcome(
+                        callId=call.id,
+                        validation_passed=result.validation_passed,
+                        validated_output=result.validated_output,
+                        history=guard.history,
+                        raw_llm_output=result.raw_llm_output,
+                    )
+                    final_output_dict = final_validation_output.to_dict()
+                    final_output_dict["error_spans"] = [
+                        json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
+                        for x in guard.error_spans_in_output()
+                    ]
+                    yield json.dumps(final_output_dict) + "\n"
+                except Exception as e:
+                    yield json.dumps({"error": {"message": str(e)}}) + "\n"

+                if guard_history_is_enabled():
+                    serialized_history = [call.to_dict() for call in guard.history]
+                    cache_key = f"{guard.name}-{final_validation_output.call_id}"
+                    await cache_client.set(cache_key, serialized_history, 300)

             return StreamingResponse(
                 validate_streamer(guard_streamer()), media_type="application/json"
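Unlike the OpenAI-compatible route's SSE, this endpoint streams newline-delimited JSON: one fragment object per line, then a final consolidated outcome, with failures surfacing as an {"error": {"message": ...}} line. A client-side sketch, assuming only those shapes from the diff; the function name and sample line are ours:

import json

def read_validation_stream(lines):
    # `lines` is an iterable of text lines from the chunked response body.
    for line in lines:
        if not line.strip():
            continue
        obj = json.loads(line)
        if "error" in obj:
            raise RuntimeError(obj["error"]["message"])
        yield obj  # fragment dicts, then the final validation outcome

for obj in read_validation_stream(['{"validation_passed": true}']):
    print(obj["validation_passed"])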
@@ -260,15 +288,14 @@ async def validate_streamer(guard_iter):
             *args,
             **payload,
         )

         if inspect.iscoroutine(execution):
             result: ValidationOutcome = await execution
         else:
             result: ValidationOutcome = execution

-        serialized_history = [call.to_dict() for call in guard.history]
-        cache_key = f"{guard.name}-{result.call_id}"
-        await cache_client.set(cache_key, serialized_history, 300)
+        if guard_history_is_enabled():
+            serialized_history = [call.to_dict() for call in guard.history]
+            cache_key = f"{guard.name}-{result.call_id}"
+            await cache_client.set(cache_key, serialized_history, 300)
         return result.to_dict()
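Both the streaming and non-streaming paths now persist history only when the env toggle allows it, writing under a "<guard name>-<call id>" key with a 300-second TTL. A sketch of the cache contract as the diff uses it; the in-memory stand-in below is ours, assuming a cache_client.set(key, value, ttl) signature:

class InMemoryCache:
    # Stand-in for the real cache_client; a real deployment needs a shared
    # store with expiry, which is exactly the multi-node gap motivating the toggle.
    def __init__(self):
        self._store = {}

    async def set(self, key, value, ttl_seconds):
        self._store[key] = value  # TTL ignored in this sketch

# Key scheme and TTL from the diff:
#   await cache_client.set(f"{guard.name}-{result.call_id}", serialized_history, 300)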


8 changes: 2 additions & 6 deletions tests/api/test_guards.py
@@ -14,9 +14,6 @@
 from tests.mocks.mock_guard_client import MockGuardStruct
 from guardrails_api.api.guards import router as guards_router

-
-import asyncio
-
 # TODO: Should we mock this somehow?
 # Right now it's just empty, but it technically does a file read
 register_config()
@@ -347,9 +344,8 @@ def test_openai_v1_chat_completions__call(mocker):
     )

     mock___call__ = mocker.patch.object(MockGuardStruct, "__call__")
-    future = asyncio.Future()
-    future.set_result(mock_outcome)
-    mock___call__.return_value = future
+
+    mock___call__.return_value = mock_outcome

     mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict")
     mock_from_dict.return_value = mock_guard
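Because the route now dispatches on inspect.iscoroutine itself, the test can hand back the outcome directly instead of wrapping it in an asyncio.Future. A sketch of why both branches stay covered; the stand-in guards are ours:

import inspect

def sync_guard(**kwargs):      # stands in for a sync MockGuardStruct.__call__
    return "outcome"

async def async_guard(**kwargs):
    return "outcome"

assert not inspect.iscoroutine(sync_guard())   # route takes the direct branch

coro = async_guard()
assert inspect.iscoroutine(coro)               # route takes the await branch
coro.close()                                   # avoid an un-awaited coroutine warning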