Skip to content

Streaming Demo Changes #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 25, 2024
Merged
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pgadmin-data
__pycache__/
guard-rails-api-client
guardrails-custom-validators
guardrails_api/default.env
sdk
.coverage
htmlcov
Expand Down
6 changes: 4 additions & 2 deletions guardrails_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def register_config(config: Optional[str] = None):
SourceFileLoader("config", config_file_path).load_module()


def create_app(env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None):
def create_app(
env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None
):
if os.environ.get("APP_ENVIRONMENT") != "production":
from dotenv import load_dotenv

Expand Down Expand Up @@ -86,7 +88,7 @@ def create_app(env: Optional[str] = None, config: Optional[str] = None, port: Op

pg_client = PostgresClient()
pg_client.initialize(app)

cache_client = CacheClient()
cache_client.initialize(app)

Expand Down
51 changes: 35 additions & 16 deletions guardrails_api/blueprints/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
is_guard = isinstance(export, Guard)
if is_guard:
guard_client.create_guard(export)

cache_client = CacheClient()


Expand Down Expand Up @@ -213,7 +213,6 @@ def validate(guard_name: str):
" calling guard(...)."
),
)

if llm_output is not None:
if stream:
raise HttpError(
Expand All @@ -226,32 +225,33 @@ def validate(guard_name: str):
num_reasks=num_reasks,
prompt_params=prompt_params,
llm_api=llm_api,
# api_key=openai_api_key,
**payload,
)
else:
if stream:

def guard_streamer():
guard_stream = guard(
# llm_api=llm_api,
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
stream=stream,
# api_key=openai_api_key,
*args,
**payload,
)

for result in guard_stream:
# TODO: Just make this a ValidationOutcome with history
validation_output: ValidationOutcome = ValidationOutcome(
result.validation_passed,
result.validated_output,
guard.history,
result.raw_llm_output,
validation_output: ValidationOutcome = (
ValidationOutcome.from_guard_history(guard.history.last)
)

# ValidationOutcome(
# guard.history,
# validation_passed=result.validation_passed,
# validated_output=result.validated_output,
# raw_llm_output=result.raw_llm_output,
# )
yield validation_output, cast(ValidationOutcome, result)

def validate_streamer(guard_iter):
Expand All @@ -260,10 +260,21 @@ def validate_streamer(guard_iter):
for validation_output, result in guard_iter:
next_result = result
# next_validation_output = validation_output
fragment = json.dumps(validation_output.to_response())
fragment_dict = result.to_dict()
fragment_dict["error_spans"] = list(
map(
lambda x: json.dumps(
{"start": x.start, "end": x.end, "reason": x.reason}
),
guard.error_spans_in_output(),
)
)
fragment = json.dumps(fragment_dict)
yield f"{fragment}\n"

call = guard.history.last
final_validation_output: ValidationOutcome = ValidationOutcome(
callId=call.id,
validation_passed=next_result.validation_passed,
validated_output=next_result.validated_output,
history=guard.history,
Expand All @@ -278,7 +289,16 @@ def validate_streamer(guard_iter):
# prompt_params=prompt_params,
# result=next_result
# )
final_output_json = final_validation_output.to_json()
final_output_dict = final_validation_output.to_dict()
final_output_dict["error_spans"] = list(
map(
lambda x: json.dumps(
{"start": x.start, "end": x.end, "reason": x.reason}
),
guard.error_spans_in_output(),
)
)
final_output_json = json.dumps(final_output_dict)
yield f"{final_output_json}\n"

return Response(
Expand Down Expand Up @@ -311,16 +331,15 @@ def validate_streamer(guard_iter):
# prompt_params=prompt_params,
# result=result
# )
serialized_history = [
call.to_dict() for call in guard.history
]
serialized_history = [call.to_dict() for call in guard.history]
cache_key = f"{guard.name}-{result.call_id}"
cache_client.set(cache_key, serialized_history, 300)
return result.to_dict()


@guards_bp.route("/<guard_name>/history/<call_id>", methods=["GET"])
@handle_error
def guard_history(guard_name: str, call_id: str):
    """Return the cached history for one call of a guard.

    The lookup key is ``"<guard_name>-<call_id>"``, the same shape the
    validate endpoint uses when it stores the serialized call history in
    the cache.

    Args:
        guard_name: Name of the guard, taken from the URL path.
        call_id: Identifier of the call, taken from the URL path.

    Returns:
        Whatever ``cache_client.get`` yields for the key. For any request
        method other than GET the function falls through and returns None.
    """
    if request.method == "GET":
        cache_key = f"{guard_name}-{call_id}"
        # Single return: the second, identical return that followed this
        # line could never execute.
        return cache_client.get(cache_key)
9 changes: 4 additions & 5 deletions guardrails_api/clients/cache_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@ def __new__(cls):
if cls._instance is None:
cls._instance = super(CacheClient, cls).__new__(cls)
return cls._instance


def initialize(self, app):
    """Build the cache backend for ``app`` and store it on ``self.cache``.

    Args:
        app: The application object handed to ``Cache`` (presumably the
            Flask app; confirm against the caller).

    The configuration selects the ``SimpleCache`` backend with a default
    timeout of 300 and a threshold of 50 entries. Units and eviction
    semantics are those of the ``Cache`` implementation, presumably
    Flask-Caching (TODO confirm).
    """
    # ``app`` is passed exactly once and every config key appears exactly
    # once; the duplicated lines made the call malformed.
    self.cache = Cache(
        app,
        config={
            "CACHE_TYPE": "SimpleCache",
            "CACHE_DEFAULT_TIMEOUT": 300,
            "CACHE_THRESHOLD": 50,
        },
    )

def get(self, key):
Expand All @@ -31,4 +30,4 @@ def delete(self, key):
self.cache.delete(key)

def clear(self):
    """Remove every entry from the underlying cache.

    Delegates to ``self.cache.clear()`` exactly once; the repeated second
    call was redundant.
    """
    self.cache.clear()
2 changes: 0 additions & 2 deletions guardrails_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,3 @@
and guards will be persisted into postgres. In that case,
these guards will not be initialized.
"""

from guardrails import Guard # noqa
4 changes: 3 additions & 1 deletion guardrails_api/default.env
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ PYTHONUNBUFFERED=1
LOGLEVEL="INFO"
GUARDRAILS_LOG_LEVEL="INFO"
GUARDRAILS_PROCESS_COUNT=1
SELF_ENDPOINT=http://localhost:8000
API_KEY=***REMOVED***
HOST=http://localhost
PORT=8000
OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
9 changes: 9 additions & 0 deletions guardrails_api/start
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/sh
# Local launcher for the Guardrails API under gunicorn.
#
# Exports the environment the app reads at start-up, then serves the
# application factory on port 8001.

# Any value other than "production" makes create_app() load the dotenv file.
export APP_ENVIRONMENT=local
export PYTHONUNBUFFERED=1
export LOGLEVEL="INFO"
export GUARDRAILS_LOG_LEVEL="INFO"
export GUARDRAILS_PROCESS_COUNT=1
# Keep this port in sync with the --bind address below.
export SELF_ENDPOINT=http://localhost:8001
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

# exec replaces this shell with gunicorn so signals (SIGTERM, SIGINT) reach
# the server directly instead of a wrapper shell.
exec gunicorn --bind 0.0.0.0:8001 --timeout=5 --threads=10 "guardrails_api.app:create_app()"
Loading