Streaming Demo Changes #51

Merged
merged 15 commits on Jun 25, 2024
37 changes: 22 additions & 15 deletions guardrails_api/blueprints/guards.py
@@ -213,7 +213,6 @@ def validate(guard_name: str):
" calling guard(...)."
),
)

if llm_output is not None:
if stream:
raise HttpError(
@@ -226,32 +225,32 @@ def validate(guard_name: str):
num_reasks=num_reasks,
prompt_params=prompt_params,
llm_api=llm_api,
# api_key=openai_api_key,
**payload,
)
else:
if stream:

def guard_streamer():
guard_stream = guard(
# llm_api=llm_api,
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
stream=stream,
# api_key=openai_api_key,
*args,
**payload,
)

for result in guard_stream:
# TODO: Just make this a ValidationOutcome with history
validation_output: ValidationOutcome = ValidationOutcome(
result.validation_passed,
result.validated_output,
guard.history,
result.raw_llm_output,
)

validation_output: ValidationOutcome = ValidationOutcome.from_guard_history(guard.history.last)

# ValidationOutcome(
# guard.history,
# validation_passed=result.validation_passed,
# validated_output=result.validated_output,
# raw_llm_output=result.raw_llm_output,
# )
print('RESULT', result)
yield validation_output, cast(ValidationOutcome, result)

def validate_streamer(guard_iter):
@@ -260,7 +259,12 @@ def validate_streamer(guard_iter):
for validation_output, result in guard_iter:
next_result = result
# next_validation_output = validation_output
fragment = json.dumps(validation_output.to_response())
fragment_dict = result.to_dict()
fragment_dict['error_spans'] = guard.error_spans_in_output()
fragment = json.dumps(fragment_dict)
print('fragment!', fragment)
print(guard.error_spans_in_output())
print(guard.history.last.iterations.last.outputs.validator_logs)
yield f"{fragment}\n"

final_validation_output: ValidationOutcome = ValidationOutcome(
@@ -278,9 +282,12 @@
# prompt_params=prompt_params,
# result=next_result
# )
final_output_json = final_validation_output.to_json()
yield f"{final_output_json}\n"

final_output_dict = final_validation_output.to_dict()
final_output_dict['error_spans'] = guard.error_spans_in_output()
print('error spans', guard.error_spans_in_output())
final_output_json = json.dumps(final_output_dict)
print('final output', final_output_json)
yield f"{final_validation_output.to_json()}\n"
return Response(
stream_with_context(validate_streamer(guard_streamer())),
content_type="application/json",
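Note: the streaming branch above yields newline-delimited JSON, with each fragment now carrying error_spans from guard.error_spans_in_output(). A minimal client sketch for consuming that stream is below; the host, route, and request-body keys are illustrative assumptions, not part of this PR.

import json
import requests

# Hypothetical URL and payload keys; the real route and schema come from the
# guards blueprint and the deployed server, and may differ.
url = "http://localhost:8000/guards/Output%20Guard/validate"
payload = {"llmApi": "litellm.completion", "stream": True}

with requests.post(url, json=payload, stream=True) as resp:
    # The endpoint streams one JSON object per line.
    for line in resp.iter_lines():
        if not line:
            continue
        fragment = json.loads(line)
        # Per this PR, each fragment includes validator error spans alongside
        # the usual validation fields.
        print(fragment.get("error_spans"), fragment.get("validated_output"))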
22 changes: 21 additions & 1 deletion guardrails_api/config.py
@@ -7,5 +7,25 @@
and guards will be persisted into postgres. In that case,
these guards will not be initialized.
"""
# from guardrails_hub import AtomicFactuality # noqa
from guardrails import Guard
from guardrails.hub import (
DetectPII,
CompetitorCheck
)


no_guards = Guard()
no_guards.name = "No Guards"

output_guard = Guard()
output_guard.name = "Output Guard"
output_guard.use_many(
DetectPII(
pii_entities='pii'
),
CompetitorCheck(
competitors=['OpenAI', 'Anthropic']
)
)

from guardrails import Guard # noqa
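For a quick local sanity check of the guards configured above (outside the API server), something like the sketch below should work once the DetectPII and CompetitorCheck validators are installed from the Guardrails Hub; the sample text and the use of guard.parse on a pre-generated string are assumptions, not part of this PR.

# Sketch only: mirrors the "Output Guard" configured in config.py above.
from guardrails import Guard
from guardrails.hub import CompetitorCheck, DetectPII

output_guard = Guard()
output_guard.name = "Output Guard"
output_guard.use_many(
    DetectPII(pii_entities="pii"),
    CompetitorCheck(competitors=["OpenAI", "Anthropic"]),
)

# Validate an already-produced LLM output; no LLM call is made here.
outcome = output_guard.parse(llm_output="Email jane@example.com about the OpenAI launch.")
print(outcome.validation_passed)
print(outcome.validated_output)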
36 changes: 16 additions & 20 deletions guardrails_api/utils/get_llm_callable.py
@@ -11,27 +11,23 @@
)
from guardrails_api_client.models.llm_resource import LLMResource


litellm.set_verbose = True
def get_llm_callable(
llm_api: str,
) -> Union[Callable, Callable[[Any], Awaitable[Any]]]:
try:
model = ValidatePayload(llm_api)
# TODO: Add error handling and throw 400
if model is LLMResource.OPENAI_DOT_COMPLETION_DOT_CREATE:
return get_static_openai_create_func()
elif model is LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_CREATE:
return get_static_openai_chat_create_func()
elif model is LLMResource.OPENAI_DOT_COMPLETION_DOT_ACREATE:
return get_static_openai_acreate_func()
elif model is LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_ACREATE:
return get_static_openai_chat_acreate_func()
elif model is LLMResource.LITELLM_DOT_COMPLETION:
return litellm.completion
elif model is LLMResource.LITELLM_DOT_ACOMPLETION:
return litellm.acompletion

else:
pass
except Exception:
# TODO: Add error handling and throw 400
if llm_api == LLMResource.OPENAI_DOT_COMPLETION_DOT_CREATE.value:
return get_static_openai_create_func()
elif llm_api == LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_CREATE.value:
return get_static_openai_chat_create_func()
elif llm_api == LLMResource.OPENAI_DOT_COMPLETION_DOT_ACREATE.value:
return get_static_openai_acreate_func()
elif llm_api == LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_ACREATE.value:
return get_static_openai_chat_acreate_func()
elif llm_api == LLMResource.LITELLM_DOT_COMPLETION.value:
return litellm.completion
elif llm_api == LLMResource.LITELLM_DOT_ACOMPLETION.value:
return litellm.acompletion
else:
pass

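A small usage sketch of the simplified resolver above. The string identifier is assumed to match the corresponding LLMResource enum value (e.g. "litellm.completion" for LITELLM_DOT_COMPLETION); unknown identifiers still fall through to None, pending the TODO about raising a 400.

# Sketch only: resolve an llm_api identifier to a callable for guard(...).
from guardrails_api.utils.get_llm_callable import get_llm_callable

llm_api = get_llm_callable("litellm.completion")  # assumed enum value
if llm_api is None:
    # Unknown identifiers currently return None; the TODO in the diff notes
    # this should eventually surface as an HTTP 400 instead.
    raise ValueError("Unsupported llm_api identifier")

# llm_api is now litellm.completion and can be passed to guard(...) as llm_api.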