Streaming Demo Changes #51

Merged · 15 commits · Jun 25, 2024
6 changes: 4 additions & 2 deletions guardrails_api/app.py
@@ -47,7 +47,9 @@ def register_config(config: Optional[str] = None):
SourceFileLoader("config", config_file_path).load_module()


def create_app(env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None):
def create_app(
env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None
):
if os.environ.get("APP_ENVIRONMENT") != "production":
from dotenv import load_dotenv

@@ -86,7 +88,7 @@ def create_app(env: Optional[str] = None, config: Optional[str] = None, port: Op

pg_client = PostgresClient()
pg_client.initialize(app)

cache_client = CacheClient()
cache_client.initialize(app)

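For context, a minimal sketch of invoking the reformatted app factory locally. The keyword usage mirrors the new signature; the `.env`/`config.py` paths and the `app.run()` call are illustrative assumptions, not part of this diff.

```python
# Hedged local-run sketch of create_app(env, config, port) as reformatted above.
# File paths below are placeholders; all three arguments remain optional.
from guardrails_api.app import create_app

app = create_app(env=".env", config="config.py", port=8000)
app.run(port=8000)  # assumes the factory returns a standard Flask app
```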
51 changes: 35 additions & 16 deletions guardrails_api/blueprints/guards.py
@@ -35,7 +35,7 @@
is_guard = isinstance(export, Guard)
if is_guard:
guard_client.create_guard(export)

cache_client = CacheClient()


@@ -213,7 +213,6 @@ def validate(guard_name: str):
" calling guard(...)."
),
)

if llm_output is not None:
if stream:
raise HttpError(
@@ -226,32 +225,33 @@
num_reasks=num_reasks,
prompt_params=prompt_params,
llm_api=llm_api,
# api_key=openai_api_key,
**payload,
)
else:
if stream:

def guard_streamer():
guard_stream = guard(
# llm_api=llm_api,
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
stream=stream,
# api_key=openai_api_key,
*args,
**payload,
)

for result in guard_stream:
# TODO: Just make this a ValidationOutcome with history
validation_output: ValidationOutcome = ValidationOutcome(
result.validation_passed,
result.validated_output,
guard.history,
result.raw_llm_output,
validation_output: ValidationOutcome = (
ValidationOutcome.from_guard_history(guard.history.last)
)

# ValidationOutcome(
# guard.history,
# validation_passed=result.validation_passed,
# validated_output=result.validated_output,
# raw_llm_output=result.raw_llm_output,
# )
yield validation_output, cast(ValidationOutcome, result)

def validate_streamer(guard_iter):
@@ -260,10 +260,21 @@ def validate_streamer(guard_iter):
for validation_output, result in guard_iter:
next_result = result
# next_validation_output = validation_output
fragment = json.dumps(validation_output.to_response())
fragment_dict = result.to_dict()
fragment_dict["error_spans"] = list(
map(
lambda x: json.dumps(
{"start": x.start, "end": x.end, "reason": x.reason}
),
guard.error_spans_in_output(),
)
)
fragment = json.dumps(fragment_dict)
yield f"{fragment}\n"

call = guard.history.last
final_validation_output: ValidationOutcome = ValidationOutcome(
callId=call.id,
validation_passed=next_result.validation_passed,
validated_output=next_result.validated_output,
history=guard.history,
@@ -278,7 +289,16 @@ def validate_streamer(guard_iter):
# prompt_params=prompt_params,
# result=next_result
# )
final_output_json = final_validation_output.to_json()
final_output_dict = final_validation_output.to_dict()
final_output_dict["error_spans"] = list(
map(
lambda x: json.dumps(
{"start": x.start, "end": x.end, "reason": x.reason}
),
guard.error_spans_in_output(),
)
)
final_output_json = json.dumps(final_output_dict)
yield f"{final_output_json}\n"

return Response(
@@ -311,16 +331,15 @@ def validate_streamer(guard_iter):
# prompt_params=prompt_params,
# result=result
# )
serialized_history = [
call.to_dict() for call in guard.history
]
serialized_history = [call.to_dict() for call in guard.history]
cache_key = f"{guard.name}-{result.call_id}"
cache_client.set(cache_key, serialized_history, 300)
return result.to_dict()


@guards_bp.route("/<guard_name>/history/<call_id>", methods=["GET"])
@handle_error
def guard_history(guard_name: str, call_id: str):
if request.method == "GET":
cache_key = f"{guard_name}-{call_id}"
return cache_client.get(cache_key)
return cache_client.get(cache_key)
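With the changes above, the streaming branch of the validate endpoint emits newline-delimited JSON fragments, each augmented with an `error_spans` list from `guard.error_spans_in_output()`. A hedged client-side sketch of consuming that stream follows; the route prefix, payload keys, and use of `requests` are assumptions, not taken from this PR.

```python
# Hypothetical consumer for the newline-delimited JSON stream produced by
# validate_streamer above. Endpoint path and request body shape are assumed.
import json
import requests

with requests.post(
    "http://localhost:8000/guards/my-guard/validate",  # assumed route prefix
    json={"stream": True, "prompt_params": {}},        # illustrative payload
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        fragment = json.loads(line)
        # Key names depend on ValidationOutcome.to_dict(); "error_spans" is the
        # field added explicitly by validate_streamer.
        print(fragment.get("error_spans"), fragment)
```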
9 changes: 4 additions & 5 deletions guardrails_api/clients/cache_client.py
@@ -9,16 +9,15 @@ def __new__(cls):
if cls._instance is None:
cls._instance = super(CacheClient, cls).__new__(cls)
return cls._instance


def initialize(self, app):
self.cache = Cache(
app,
app,
config={
"CACHE_TYPE": "SimpleCache",
"CACHE_DEFAULT_TIMEOUT": 300,
"CACHE_THRESHOLD": 50
}
"CACHE_THRESHOLD": 50,
},
)

def get(self, key):
@@ -31,4 +30,4 @@ def delete(self, key):
self.cache.delete(key)

def clear(self):
self.cache.clear()
self.cache.clear()
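The `CacheClient` above is a process-wide singleton backed by Flask-Caching's `SimpleCache`. A brief sketch of that behaviour, assuming the app factory has already called `initialize(app)`; the key and value below are placeholders.

```python
# Singleton behaviour of CacheClient: repeated construction returns one instance,
# so the guards blueprint and the app factory share the same caching backend.
from guardrails_api.clients.cache_client import CacheClient

first = CacheClient()
second = CacheClient()
assert first is second  # __new__ returns the cached instance

# After create_app() has run first.initialize(app):
# first.set("my-guard-<call-id>", serialized_history, 300)  # 300 s, as in guards.py
# history = first.get("my-guard-<call-id>")
```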
2 changes: 0 additions & 2 deletions guardrails_api/config.py
@@ -7,5 +7,3 @@
and guards will be persisted into postgres. In that case,
these guards will not be initialized.
"""

from guardrails import Guard # noqa
2 changes: 1 addition & 1 deletion guardrails_api/default.env
@@ -5,4 +5,4 @@ GUARDRAILS_LOG_LEVEL="INFO"
GUARDRAILS_PROCESS_COUNT=1
HOST=http://localhost
PORT=8000
OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
9 changes: 9 additions & 0 deletions guardrails_api/start
@@ -0,0 +1,9 @@
export APP_ENVIRONMENT=local
export PYTHONUNBUFFERED=1
export LOGLEVEL="INFO"
export GUARDRAILS_LOG_LEVEL="INFO"
export GUARDRAILS_PROCESS_COUNT=1
export SELF_ENDPOINT=http://localhost:8000
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "guardrails_api.app:create_app()"
36 changes: 14 additions & 22 deletions guardrails_api/utils/get_llm_callable.py
@@ -6,32 +6,24 @@
get_static_openai_acreate_func,
get_static_openai_chat_acreate_func,
)
from guardrails_api_client.models.validate_payload import (
ValidatePayload,
)
from guardrails_api_client.models.llm_resource import LLMResource


def get_llm_callable(
llm_api: str,
) -> Union[Callable, Callable[[Any], Awaitable[Any]]]:
try:
model = ValidatePayload(llm_api)
# TODO: Add error handling and throw 400
if model is LLMResource.OPENAI_DOT_COMPLETION_DOT_CREATE:
return get_static_openai_create_func()
elif model is LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_CREATE:
return get_static_openai_chat_create_func()
elif model is LLMResource.OPENAI_DOT_COMPLETION_DOT_ACREATE:
return get_static_openai_acreate_func()
elif model is LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_ACREATE:
return get_static_openai_chat_acreate_func()
elif model is LLMResource.LITELLM_DOT_COMPLETION:
return litellm.completion
elif model is LLMResource.LITELLM_DOT_ACOMPLETION:
return litellm.acompletion

else:
pass
except Exception:
# TODO: Add error handling and throw 400
if llm_api == LLMResource.OPENAI_DOT_COMPLETION_DOT_CREATE.value:
return get_static_openai_create_func()
elif llm_api == LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_CREATE.value:
return get_static_openai_chat_create_func()
elif llm_api == LLMResource.OPENAI_DOT_COMPLETION_DOT_ACREATE.value:
return get_static_openai_acreate_func()
elif llm_api == LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_ACREATE.value:
return get_static_openai_chat_acreate_func()
elif llm_api == LLMResource.LITELLM_DOT_COMPLETION.value:
return litellm.completion
elif llm_api == LLMResource.LITELLM_DOT_ACOMPLETION.value:
return litellm.acompletion
else:
pass
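`get_llm_callable` now dispatches on the raw string instead of constructing a `ValidatePayload`. A short usage sketch; the assumption here is that each `LLMResource` member's `.value` is the dotted string the API receives (for example "litellm.completion"), so check the enum for the exact literals.

```python
# Hedged usage sketch for the string-based dispatch above.
from guardrails_api.utils.get_llm_callable import get_llm_callable
from guardrails_api_client.models.llm_resource import LLMResource

llm_api = get_llm_callable(LLMResource.LITELLM_DOT_COMPLETION.value)
# llm_api should now be litellm.completion; unrecognized strings fall through
# the else branch and the function returns None.
```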
76 changes: 18 additions & 58 deletions sample-config.py
@@ -1,71 +1,31 @@
'''
"""
All guards defined here will be initialized, if and only if
the application is using in memory guards.

The application will use in memory guards if pg_host is left
undefined. Otherwise, a postgres instance will be started
undefined. Otherwise, a postgres instance will be started
and guards will be persisted into postgres. In that case,
these guards will not be initialized.
'''
"""

from guardrails import Guard
from guardrails.hub import RegexMatch, ValidChoices, ValidLength #, RestrictToTopic

name_case = Guard(
name='name-case',
description='Checks that a string is in Name Case format.'
).use(
RegexMatch(regex="^(?:[A-Z][^\s]*\s?)+$")
)

all_caps = Guard(
name='all-caps',
description='Checks that a string is all capital.'
).use(
RegexMatch(regex="^[A-Z\\s]*$")
)

lower_case = Guard(
name='lower-case',
description='Checks that a string is all lowercase.'
).use(
RegexMatch(regex="^[a-z\\s]*$")
).use(
ValidLength(1, 100)
).use(
ValidChoices(["music", "cooking", "camping", "outdoors"])
from guardrails.hub import (
DetectPII,
CompetitorCheck
)

print(lower_case.to_json())


no_guards = Guard()
no_guards.name = "No Guards"

output_guard = Guard()
output_guard.name = "Output Guard"
output_guard.use_many(
DetectPII(
pii_entities='pii'
),
CompetitorCheck(
competitors=['OpenAI', 'Anthropic']
)
)

# valid_topics = ["music", "cooking", "camping", "outdoors"]
# invalid_topics = ["sports", "work", "ai"]
# all_topics = [*valid_topics, *invalid_topics]

# def custom_llm (text: str, *args, **kwargs):
# return [
# {
# "name": t,
# "present": (t in text),
# "confidence": 5
# }
# for t in all_topics
# ]

# custom_code_guard = Guard(
# name='custom',
# description='Uses a custom llm for RestrictToTopic'
# ).use(
# RestrictToTopic(
# valid_topics=valid_topics,
# invalid_topics=invalid_topics,
# llm_callable=custom_llm,
# disable_classifier=True,
# disable_llm=False,
# # Pass this so it doesn't load the bart model
# classifier_api_endpoint="https://m-1e7af27102f54c3a9eb9cb11aa4715bd-m.default.model-v2.inferless.com/v2/models/RestrictToTopic_1e7af27102f54c3a9eb9cb11aa4715bd/versions/1/infer",
# )
# )
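As the docstring notes, these guards are only constructed when the server runs with in-memory storage. A hedged sketch of how a config file like this gets picked up, based on the `register_config` path shown in `app.py` earlier; the config path below is illustrative.

```python
# Illustration of the config-loading path: register_config loads the file as a
# module, and the guards blueprint then registers each module-level Guard.
from importlib.machinery import SourceFileLoader

config_module = SourceFileLoader("config", "./sample-config.py").load_module()
# The import-time loop in guards.py walks the module's exports and calls
# guard_client.create_guard(export) for every isinstance(export, Guard).
```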
29 changes: 15 additions & 14 deletions tests/blueprints/test_guards.py
@@ -45,7 +45,12 @@ def test_route_setup(mocker):
from guardrails_api.blueprints.guards import guards_bp

assert guards_bp.route_call_count == 4
assert guards_bp.routes == ["/", "/<guard_name>", "/<guard_name>/validate", "/<guard_name>/history/<call_id>"]
assert guards_bp.routes == [
"/",
"/<guard_name>",
"/<guard_name>/validate",
"/<guard_name>/history/<call_id>",
]


def test_guards__get(mocker):
@@ -83,7 +88,8 @@ def test_guards__post_pg(mocker):
mocker.patch("flask.Blueprint", new=MockBlueprint)
mocker.patch("guardrails_api.blueprints.guards.request", mock_request)
mock_from_request = mocker.patch(
"guardrails_api.blueprints.guards.GuardStruct.from_dict", return_value=mock_guard
"guardrails_api.blueprints.guards.GuardStruct.from_dict",
return_value=mock_guard,
)
mock_create_guard = mocker.patch(
"guardrails_api.blueprints.guards.guard_client.create_guard",
@@ -185,7 +191,8 @@ def test_guard__put_pg(mocker):
mocker.patch("guardrails_api.blueprints.guards.request", mock_request)

mock_from_request = mocker.patch(
"guardrails_api.blueprints.guards.GuardStruct.from_dict", return_value=mock_guard
"guardrails_api.blueprints.guards.GuardStruct.from_dict",
return_value=mock_guard,
)
mock_upsert_guard = mocker.patch(
"guardrails_api.blueprints.guards.guard_client.upsert_guard",
@@ -387,10 +394,8 @@ def test_validate__parse(mocker):
"guardrails_api.blueprints.guards.guard_client.get_guard",
return_value=mock_guard,
)

mocker.patch(
"guardrails_api.blueprints.guards.CacheClient.set"
)

mocker.patch("guardrails_api.blueprints.guards.CacheClient.set")

# mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer)

@@ -451,7 +456,7 @@ def test_validate__call(mocker):
call_id="mock-call-id",
raw_llm_output="Hello world!",
validated_output=None,
validation_passed=False
validation_passed=False,
)

mock___call__ = mocker.patch.object(MockGuardStruct, "__call__")
@@ -484,12 +489,8 @@ def test_validate__call(mocker):
"guardrails_api.blueprints.guards.get_llm_callable",
return_value="openai.Completion.create",
)

mocker.patch(
"guardrails_api.blueprints.guards.CacheClient.set"
)



mocker.patch("guardrails_api.blueprints.guards.CacheClient.set")

# mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer)
