diff --git a/guardrails_api/app.py b/guardrails_api/app.py
index 088d172..ed494c3 100644
--- a/guardrails_api/app.py
+++ b/guardrails_api/app.py
@@ -47,7 +47,9 @@ def register_config(config: Optional[str] = None):
         SourceFileLoader("config", config_file_path).load_module()
 
 
-def create_app(env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None):
+def create_app(
+    env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None
+):
     if os.environ.get("APP_ENVIRONMENT") != "production":
         from dotenv import load_dotenv
 
@@ -86,7 +88,7 @@ def create_app(env: Optional[str] = None, config: Optional[str] = None, port: Op
     pg_client = PostgresClient()
     pg_client.initialize(app)
 
-    
+
     cache_client = CacheClient()
     cache_client.initialize(app)
 
diff --git a/guardrails_api/blueprints/guards.py b/guardrails_api/blueprints/guards.py
index 6ebc790..29575fc 100644
--- a/guardrails_api/blueprints/guards.py
+++ b/guardrails_api/blueprints/guards.py
@@ -35,7 +35,7 @@
         is_guard = isinstance(export, Guard)
         if is_guard:
             guard_client.create_guard(export)
-    
+
 
 cache_client = CacheClient()
 
@@ -213,7 +213,6 @@ def validate(guard_name: str):
                 " calling guard(...)."
             ),
         )
-
     if llm_output is not None:
         if stream:
             raise HttpError(
@@ -226,7 +225,6 @@ def validate(guard_name: str):
             num_reasks=num_reasks,
             prompt_params=prompt_params,
             llm_api=llm_api,
-            # api_key=openai_api_key,
             **payload,
         )
     else:
@@ -234,24 +232,26 @@ def validate(guard_name: str):
 
         def guard_streamer():
             guard_stream = guard(
-                # llm_api=llm_api,
+                llm_api=llm_api,
                 prompt_params=prompt_params,
                 num_reasks=num_reasks,
                 stream=stream,
-                # api_key=openai_api_key,
                 *args,
                 **payload,
             )
             for result in guard_stream:
                 # TODO: Just make this a ValidationOutcome with history
-                validation_output: ValidationOutcome = ValidationOutcome(
-                    result.validation_passed,
-                    result.validated_output,
-                    guard.history,
-                    result.raw_llm_output,
+                validation_output: ValidationOutcome = (
+                    ValidationOutcome.from_guard_history(guard.history.last)
                 )
+                # ValidationOutcome(
+                #     guard.history,
+                #     validation_passed=result.validation_passed,
+                #     validated_output=result.validated_output,
+                #     raw_llm_output=result.raw_llm_output,
+                # )
                 yield validation_output, cast(ValidationOutcome, result)
 
         def validate_streamer(guard_iter):
@@ -260,10 +260,21 @@ def validate_streamer(guard_iter):
             for validation_output, result in guard_iter:
                 next_result = result
                 # next_validation_output = validation_output
-                fragment = json.dumps(validation_output.to_response())
+                fragment_dict = result.to_dict()
+                fragment_dict["error_spans"] = list(
+                    map(
+                        lambda x: json.dumps(
+                            {"start": x.start, "end": x.end, "reason": x.reason}
+                        ),
+                        guard.error_spans_in_output(),
+                    )
+                )
+                fragment = json.dumps(fragment_dict)
                 yield f"{fragment}\n"
 
+            call = guard.history.last
             final_validation_output: ValidationOutcome = ValidationOutcome(
+                callId=call.id,
                 validation_passed=next_result.validation_passed,
                 validated_output=next_result.validated_output,
                 history=guard.history,
@@ -278,7 +289,16 @@ def validate_streamer(guard_iter):
             #     prompt_params=prompt_params,
             #     result=next_result
             # )
-            final_output_json = final_validation_output.to_json()
+            final_output_dict = final_validation_output.to_dict()
+            final_output_dict["error_spans"] = list(
+                map(
+                    lambda x: json.dumps(
+                        {"start": x.start, "end": x.end, "reason": x.reason}
+                    ),
+                    guard.error_spans_in_output(),
+                )
+            )
+            final_output_json = json.dumps(final_output_dict)
             yield f"{final_output_json}\n"
 
         return Response(
@@ -311,16 +331,15 @@ def validate_streamer(guard_iter):
    #     prompt_params=prompt_params,
    #     result=result
    # )
-    serialized_history = [
-        call.to_dict() for call in guard.history
-    ]
+    serialized_history = [call.to_dict() for call in guard.history]
    cache_key = f"{guard.name}-{result.call_id}"
    cache_client.set(cache_key, serialized_history, 300)
    return result.to_dict()
 
+
 @guards_bp.route("/<guard_name>/history/<call_id>", methods=["GET"])
 @handle_error
 def guard_history(guard_name: str, call_id: str):
     if request.method == "GET":
         cache_key = f"{guard_name}-{call_id}"
-        return cache_client.get(cache_key)
\ No newline at end of file
+        return cache_client.get(cache_key)
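Reviewer note: with the change above, the streaming branch emits one JSON fragment per line, each carrying its own serialized `error_spans`, and the final fragment carries the call id that keys the 300-second history cache. A minimal client sketch, assuming a server on localhost:8000, a guard named `my-guard`, and an `llmOutput` payload key (all hypothetical, inferred from this diff rather than confirmed API):

```python
# Sketch only: endpoint shapes are inferred from the diff above; the
# payload key "llmOutput" and the "callId" field are assumptions.
import json

import requests

BASE = "http://localhost:8000/guards"


def stream_validate(guard_name: str, llm_output: str) -> dict:
    resp = requests.post(
        f"{BASE}/{guard_name}/validate",
        json={"llmOutput": llm_output, "stream": True},
        stream=True,
    )
    final = {}
    for line in resp.iter_lines():
        if not line:
            continue
        fragment = json.loads(line)
        # Each fragment now includes serialized error spans.
        for span in fragment.get("error_spans", []):
            print("error span:", json.loads(span))
        final = fragment
    return final


outcome = stream_validate("my-guard", "text to validate")
call_id = outcome.get("callId")
if call_id:
    # The final fragment's call id keys the cached call history.
    history = requests.get(f"{BASE}/my-guard/history/{call_id}").json()
    print(history)
```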
diff --git a/guardrails_api/clients/cache_client.py b/guardrails_api/clients/cache_client.py
index e24bcb5..dc45886 100644
--- a/guardrails_api/clients/cache_client.py
+++ b/guardrails_api/clients/cache_client.py
@@ -9,16 +9,15 @@ def __new__(cls):
         if cls._instance is None:
             cls._instance = super(CacheClient, cls).__new__(cls)
         return cls._instance
-
     def initialize(self, app):
         self.cache = Cache(
-            app, 
+            app,
             config={
                 "CACHE_TYPE": "SimpleCache",
                 "CACHE_DEFAULT_TIMEOUT": 300,
-                "CACHE_THRESHOLD": 50
-            }
+                "CACHE_THRESHOLD": 50,
+            },
         )
 
     def get(self, key):
@@ -31,4 +30,4 @@ def delete(self, key):
         self.cache.delete(key)
 
     def clear(self):
-        self.cache.clear()
\ No newline at end of file
+        self.cache.clear()
diff --git a/guardrails_api/config.py b/guardrails_api/config.py
index eab62ad..5801d80 100644
--- a/guardrails_api/config.py
+++ b/guardrails_api/config.py
@@ -7,5 +7,3 @@
 and guards will be persisted into postgres. In that case,
 these guards will not be initialized.
 """
-
-from guardrails import Guard  # noqa
diff --git a/guardrails_api/default.env b/guardrails_api/default.env
index 6e37d0b..457209c 100644
--- a/guardrails_api/default.env
+++ b/guardrails_api/default.env
@@ -5,4 +5,4 @@ GUARDRAILS_LOG_LEVEL="INFO"
 GUARDRAILS_PROCESS_COUNT=1
 HOST=http://localhost
 PORT=8000
-OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
\ No newline at end of file
+OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
diff --git a/guardrails_api/start b/guardrails_api/start
new file mode 100755
index 0000000..793fc4a
--- /dev/null
+++ b/guardrails_api/start
@@ -0,0 +1,9 @@
+export APP_ENVIRONMENT=local
+export PYTHONUNBUFFERED=1
+export LOGLEVEL="INFO"
+export GUARDRAILS_LOG_LEVEL="INFO"
+export GUARDRAILS_PROCESS_COUNT=1
+export SELF_ENDPOINT=http://localhost:8000
+export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
+
+gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "guardrails_api.app:create_app()"
diff --git a/guardrails_api/utils/get_llm_callable.py b/guardrails_api/utils/get_llm_callable.py
index 72345ee..49d7737 100644
--- a/guardrails_api/utils/get_llm_callable.py
+++ b/guardrails_api/utils/get_llm_callable.py
@@ -6,32 +6,24 @@
     get_static_openai_acreate_func,
     get_static_openai_chat_acreate_func,
 )
-from guardrails_api_client.models.validate_payload import (
-    ValidatePayload,
-)
 from guardrails_api_client.models.llm_resource import LLMResource
 
 
 def get_llm_callable(
     llm_api: str,
 ) -> Union[Callable, Callable[[Any], Awaitable[Any]]]:
-    try:
-        model = ValidatePayload(llm_api)
-        # TODO: Add error handling and throw 400
-        if model is LLMResource.OPENAI_DOT_COMPLETION_DOT_CREATE:
-            return get_static_openai_create_func()
-        elif model is LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_CREATE:
-            return get_static_openai_chat_create_func()
-        elif model is LLMResource.OPENAI_DOT_COMPLETION_DOT_ACREATE:
-            return get_static_openai_acreate_func()
-        elif model is LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_ACREATE:
-            return get_static_openai_chat_acreate_func()
-        elif model is LLMResource.LITELLM_DOT_COMPLETION:
-            return litellm.completion
-        elif model is LLMResource.LITELLM_DOT_ACOMPLETION:
-            return litellm.acompletion
-
-        else:
-            pass
-    except Exception:
+    # TODO: Add error handling and throw 400
+    if llm_api == LLMResource.OPENAI_DOT_COMPLETION_DOT_CREATE.value:
+        return get_static_openai_create_func()
+    elif llm_api == LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_CREATE.value:
+        return get_static_openai_chat_create_func()
+    elif llm_api == LLMResource.OPENAI_DOT_COMPLETION_DOT_ACREATE.value:
+        return get_static_openai_acreate_func()
+    elif llm_api == LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_ACREATE.value:
+        return get_static_openai_chat_acreate_func()
+    elif llm_api == LLMResource.LITELLM_DOT_COMPLETION.value:
+        return litellm.completion
+    elif llm_api == LLMResource.LITELLM_DOT_ACOMPLETION.value:
+        return litellm.acompletion
+    else:
         pass
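Reviewer note: `get_llm_callable` now dispatches on the raw string instead of round-tripping through `ValidatePayload`. A quick sketch of the intended behavior, assuming the `LLMResource` enum values are the dotted callable names (e.g. `"litellm.completion"`, an assumption worth checking against `guardrails_api_client`):

```python
# Assumed enum values; verify against guardrails_api_client.models.llm_resource.
from guardrails_api.utils.get_llm_callable import get_llm_callable

fn = get_llm_callable("litellm.completion")  # expected: litellm.completion

# Unrecognized strings currently fall through the final `else: pass` and
# return None; that is what the TODO about returning a 400 refers to.
assert get_llm_callable("not.a.real.llm") is None
```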
diff --git a/sample-config.py b/sample-config.py
index 7df5938..2677590 100644
--- a/sample-config.py
+++ b/sample-config.py
@@ -1,71 +1,31 @@
-'''
+"""
 All guards defined here will be initialized, if and only if
 the application is using in memory guards. The application
 will use in memory guards if pg_host is left
-undefined. Otherwise, a postgres instance will be started 
+undefined. Otherwise, a postgres instance will be started
 and guards will be persisted into postgres. In that case,
 these guards will not be initialized.
-'''
+"""
 
 from guardrails import Guard
-from guardrails.hub import RegexMatch, ValidChoices, ValidLength #, RestrictToTopic
-
-name_case = Guard(
-    name='name-case',
-    description='Checks that a string is in Name Case format.'
-).use(
-    RegexMatch(regex="^(?:[A-Z][^\s]*\s?)+$")
-)
-
-all_caps = Guard(
-    name='all-caps',
-    description='Checks that a string is all capital.'
-).use(
-    RegexMatch(regex="^[A-Z\\s]*$")
-)
-
-lower_case = Guard(
-    name='lower-case',
-    description='Checks that a string is all lowercase.'
-).use(
-    RegexMatch(regex="^[a-z\\s]*$")
-).use(
-    ValidLength(1, 100)
-).use(
-    ValidChoices(["music", "cooking", "camping", "outdoors"])
+from guardrails.hub import (
+    DetectPII,
+    CompetitorCheck
 )
 
-print(lower_case.to_json())
-
+no_guards = Guard()
+no_guards.name = "No Guards"
+
+output_guard = Guard()
+output_guard.name = "Output Guard"
+output_guard.use_many(
+    DetectPII(
+        pii_entities='pii'
+    ),
+    CompetitorCheck(
+        competitors=['OpenAI', 'Anthropic']
+    )
+)
 
-# valid_topics = ["music", "cooking", "camping", "outdoors"]
-# invalid_topics = ["sports", "work", "ai"]
-# all_topics = [*valid_topics, *invalid_topics]
-
-# def custom_llm (text: str, *args, **kwargs):
-#     return [
-#         {
-#             "name": t,
-#             "present": (t in text),
-#             "confidence": 5
-#         }
-#         for t in all_topics
-#     ]
-
-# custom_code_guard = Guard(
-#     name='custom',
-#     description='Uses a custom llm for RestrictToTopic'
-# ).use(
-#     RestrictToTopic(
-#         valid_topics=valid_topics,
-#         invalid_topics=invalid_topics,
-#         llm_callable=custom_llm,
-#         disable_classifier=True,
-#         disable_llm=False,
-#         # Pass this so it doesn't load the bart model
-#         classifier_api_endpoint="https://m-1e7af27102f54c3a9eb9cb11aa4715bd-m.default.model-v2.inferless.com/v2/models/RestrictToTopic_1e7af27102f54c3a9eb9cb11aa4715bd/versions/1/infer",
-#     )
-# )
\ No newline at end of file
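Reviewer note: the sample config now registers a pass-through guard and an output guard built from hub validators. A local smoke test of the new guards, assuming the `DetectPII` and `CompetitorCheck` validators are installed from the hub (the input string is illustrative):

```python
# Exercise the "Output Guard" from sample-config.py without the server.
from guardrails import Guard
from guardrails.hub import DetectPII, CompetitorCheck

output_guard = Guard()
output_guard.name = "Output Guard"
output_guard.use_many(
    DetectPII(pii_entities="pii"),
    CompetitorCheck(competitors=["OpenAI", "Anthropic"]),
)

# Both validators run against the supplied output.
result = output_guard.validate("Reach me at jane@example.com")
print(result.validation_passed)  # expected False: the email is PII
```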
"guardrails_api.blueprints.guards.CacheClient.set" - ) - - + + mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) diff --git a/tests/clients/test_pg_guard_client.py b/tests/clients/test_pg_guard_client.py index b0c9813..0b94224 100644 --- a/tests/clients/test_pg_guard_client.py +++ b/tests/clients/test_pg_guard_client.py @@ -174,7 +174,9 @@ def test_get_guards(mocker): guards = [guard_one, guard_two] mock_all.return_value = guards - mock_from_guard_item = mocker.patch("guardrails_api.clients.pg_guard_client.from_guard_item") + mock_from_guard_item = mocker.patch( + "guardrails_api.clients.pg_guard_client.from_guard_item" + ) mock_from_guard_item.side_effect = [guard_one, guard_two] from guardrails_api.clients.pg_guard_client import PGGuardClient @@ -208,7 +210,9 @@ def test_create_guard(mocker): add_spy = mocker.spy(mock_pg_client.db.session, "add") commit_spy = mocker.spy(mock_pg_client.db.session, "commit") - mock_from_guard_item = mocker.patch("guardrails_api.clients.pg_guard_client.from_guard_item") + mock_from_guard_item = mocker.patch( + "guardrails_api.clients.pg_guard_client.from_guard_item" + ) mock_from_guard_item.return_value = mock_guard from guardrails_api.clients.pg_guard_client import PGGuardClient