diff --git a/.github/workflows/pr_qa.yml b/.github/workflows/pr_qa.yml index 4a32324..5624fc4 100644 --- a/.github/workflows/pr_qa.yml +++ b/.github/workflows/pr_qa.yml @@ -15,27 +15,18 @@ jobs: steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v4 - - name: Quality Checks + - name: Install Dependencies run: | - python -m pip install --upgrade pip; + python -m venv ./.venv + source ./.venv/bin/activate make install-lock; - opentelemetry-bootstrap -a install; - make qa; - - ## - ## NOTE: - ## The below checks for a file named .version-change-type that is used to automatically increment the version number during publish flows. - ## It also checks for changes in a file called RELEASENOTES.md that are used to auto-maintain a rolling CHANGELOG. - ## This is commented out for now until we finalize what the publish/release flow will look like. - ## + make install-dev; - # changeType=$(<.version-change-type) - # if [ -z "$changeType" ]; - # then - # echo "missing file .version-change-type!" - # exit 1 - # fi - # echo "Checking for release notes..." - # git fetch origin main ${{ github.event.pull_request.base.sha }}; - # diff=$(git diff -U0 ${{ github.event.pull_request.base.sha }} ${{ github.sha }} RELEASENOTES.md); - # if [ -z "$diff" ]; then echo "Missing release notes! exiting..."; exit 1; fi + curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml + npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml + + - name: Run Quality Checks + run: | + source ./.venv/bin/activate + + make qa; \ No newline at end of file diff --git a/.gitignore b/.gitignore index 5d4b883..35218b2 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,6 @@ opentelemetry-lambda-layer open-api-spec.json open-api-spec.yml .python-version -requirements-lock-old.txt \ No newline at end of file +requirements-lock-old.txt +models +opensearch \ No newline at end of file diff --git a/Dockerfile.dev b/Dockerfile.dev index c9ef609..3fc7a6c 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -48,4 +48,4 @@ COPY . . EXPOSE 8000 -CMD opentelemetry-instrument gunicorn --bind 0.0.0.0:8000 --timeout=90 --threads=10 "app:create_app()" +CMD gunicorn --bind 0.0.0.0:8000 --timeout=90 --threads=10 "app:create_app()" diff --git a/Dockerfile.heavy b/Dockerfile.heavy index 921fb34..02273f4 100644 --- a/Dockerfile.heavy +++ b/Dockerfile.heavy @@ -44,5 +44,5 @@ COPY . . EXPOSE 8000 -CMD opentelemetry-instrument gunicorn --bind 0.0.0.0:8000 --timeout=90 --threads=10 "app:create_app()" +CMD gunicorn --bind 0.0.0.0:8000 --timeout=90 --threads=10 "app:create_app()" # CMD gunicorn --forwarded-allow-ips="*" --bind 0.0.0.0:8000 --timeout=60 --threads=10 "app:create_app()" diff --git a/Dockerfile.lite b/Dockerfile.lite new file mode 100644 index 0000000..59094d6 --- /dev/null +++ b/Dockerfile.lite @@ -0,0 +1,55 @@ +FROM public.ecr.aws/docker/library/python:3.11.6-slim + +ARG GITHUB_TOKEN +ARG HF_TOKEN + +COPY .guardrailsrc /root/.guardrailsrc + +# COPY --from=public.ecr.aws/awsguru/aws-lambda-adapter:0.7.1 /lambda-adapter /opt/extensions/lambda-adapter +# COPY ./opentelemetry-lambda-layer /opt + +# Create app directory +WORKDIR /app + +# check the version +RUN python3 --version +# start the virtual environment +RUN python3 -m venv /opt/venv + +# Enable venv +ENV PATH="/opt/venv/bin:$PATH" + +# Install git and curl +RUN apt-get update +RUN apt-get install -y git curl gcc jq + +# Copy the requirements file +COPY requirements*.txt . + +RUN curl https://truststore.pki.rds.amazonaws.com/global/global-bundle.pem -o ./global-bundle.pem + +# Install app dependencies +RUN pip install -r requirements-lock.txt + +# Download punkt data +RUN python -m nltk.downloader -d /opt/nltk_data punkt + +# RUN guardrails hub install hub://guardrails/profanity_free +RUN guardrails hub install hub://guardrails/valid_length +RUN guardrails hub install hub://guardrails/lowercase +RUN guardrails hub install hub://guardrails/regex_match + +COPY ./custom-install ./custom-install + +RUN python ./custom-install/install.py + +# Freeze dependencies +RUN pip freeze > requirements-lock.txt + +# Copy the whole folder inside the Image filesystem +COPY . . + +EXPOSE 8000 + +CMD gunicorn --bind 0.0.0.0:8000 --timeout=90 --threads=10 --limit-request-line=0 --limit-request-fields=1000 --limit-request-field_size=0 "app:create_app()" +# CMD gunicorn --forwarded-allow-ips="*" --bind 0.0.0.0:8000 --timeout=60 --threads=10 "app:create_app()" diff --git a/Dockerfile.prod b/Dockerfile.prod index db4dc3c..cbd4e41 100644 --- a/Dockerfile.prod +++ b/Dockerfile.prod @@ -37,5 +37,5 @@ COPY . . EXPOSE 8000 -CMD opentelemetry-instrument gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "app:create_app()" +CMD gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "app:create_app()" # CMD gunicorn --forwarded-allow-ips="*" --bind 0.0.0.0:8000 --timeout=60 --threads=10 "app:create_app()" diff --git a/Makefile b/Makefile index 4f0149f..4113a87 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,11 @@ # Installs production dependencies install: pip install -r requirements.txt; + # This is a workaround because of this issue: https://github.com/open-telemetry/opentelemetry-python-contrib/issues/2053 + pip uninstall aiohttp -y + pip install opentelemetry-distro opentelemetry-bootstrap -a install + pip install aiohttp # Installs development dependencies install-dev: @@ -14,14 +18,11 @@ lock: install-lock: pip install -r requirements-lock.txt -build: - make install - -dev: - bash ./dev.sh +start: + bash ./local.sh -local: - python3 ./wsgi.py +infra: + docker compose --profile infra up --build env: if [ ! -d "./.venv" ]; then echo "Creating virtual environment..."; python3 -m venv ./.venv; fi; @@ -53,11 +54,11 @@ source: source ./.venv/bin/activate test: - python3 -m pytest ./tests + pytest ./tests test-cov: coverage run --source=./src -m pytest ./tests - coverage report --fail-under=70 + coverage report --fail-under=50 view-test-cov: coverage run --source=./src -m pytest ./tests diff --git a/README.md b/README.md index 5e19fd8..5e4cbdc 100644 --- a/README.md +++ b/README.md @@ -5,65 +5,46 @@ Docker compose stub of Guardrails as a Service We strongly encourage you to use a virtual environment when developing in python. To set one up for this project run the following: ```bash -pip install venv python3 -m venv ./.venv source ./.venv/bin/activate ``` Your terminal should now show that you are working from within the virtual environment. Now you can install the dependencies: ```bash -pip install -r requirements.txt +make install ``` And start the dev server: ```bash -bash dev.sh +make start ``` -Once all servces listed in the docker-compose configuration have launched you should be able to navigate to the following: -1. Swagger documenation for the guardrails-api at http://localhost:8000 -2. PgAdmin console at http://localhost:8888 +Once the service has launched, you should be able to navigate to the Swagger documenation for the guardrails-api at http://localhost:8000 -## Diagrams -Two diagrams for two different approaches to the infrastructure for this project. -## Open API Spec -One main OpenAPI Spec for the basic endpoints discussed. This was manually written for proof-of-concept purposes. For the long run you can consider tools like the following to generate this documentation for you if you prefere to define your data objects in code: - - [flask-rest-api](https://flask-rest-api.readthedocs.io/en/stable/openapi.html) - - [apispec](https://apispec.readthedocs.io/en/latest/index.html) +### Local Infrastructure +By default, the server will start with an in-memory store for Guards. As of June 4th, 2024 this store does not support write operations via the API. In order to utilize all CRUD operations you will need a postgres database running locally and you will need to provide the following environment variables (sane defaults included for demonstration purposes): +```sh +export PGPORT=5432 +export PGDATABASE=postgres +export PGHOST=localhost +export PGUSER=${PGUSER:-postgres} +export PGPASSWORD=${PGPASSWORD:-changeme} +``` -## Notes On Serialization And Pydantic Models -This API is intended to transact in serialized rail specs. By default that means not explicitly supporting Pydantic models via the main endpoints. There are two options to choose between if we are to retain support for Pydantic models: - -1. Pydantic Models are client side only -2. Pydantic Models are server side plugins - -Both of these approaches have merits and difficulties as listed below. - -### Client Only Pydantic Models -In this scenario the user can still use Pydantic Models to capture their schema structure, but the model must have some standard way of being serialized. In Pydantic 2.x, this is avaiable via the `model_dump` and `model_dump_json` methods depending on whether we need a dictionary or a json-encoded string. The advantages of this approach is that the serialization can take place client side in the sdk and the server does not have to worry about special cases, it continues to only accept JSON as input. What we do lose with this approach is any ability to define custom validations in the Pydantic model that can be executed by the server since the model only exists client side. One way around this is through the use of sockets. When the validation endpoint is called from the sdk, if the sdk knows there is a pydantic model involved, rather than performing a standard http request it can open a web socket or rpc connection instead. Then on the server, when it doesn't find a particular validation in its registry, it can use the socket to request the sdk to run that validation instead. This would require some addtional implementation to support this back and forth; we also wouldn't have an out-of-the-box way to collect telemetry on any validators run client-side. - -### Pydantic Model Plugins -In this scenrio the user publishes thier Pydantic model to some registry that can be accessed via pip. Then, in the railspec, the user specifies the module names for any custom validators. When validation is called, a prepare step is used to install any of these dependencies and import them so that the registration annotations can be run. This seems to be possible by using a combination of `pip` and `__import__` with some error handling. It might even be possible to [host our own pip repository](https://packaging.python.org/en/latest/guides/hosting-your-own-index/) and auto-publish the user's pydantic models for them to this private registry. This would put the user's custom code within the compliance scope of GuardRails rather than a public repository. - -The benefits of this approach is that the custom validations and pydantic models for the schema(s) can be handled "natively" on the server. There is no requirement for back-and-forth communication between the server and the sdk keeping each call atomic. There is also the potential to automate most of the addtional work this approach requires of the user. - -The difficulties of this approach are simply the additonal effort required to make the models available to the server. In the simpler case the user must make this additional effort by packaging and publishing their models. This also introduces the necessity to deal with various versions of the model through pip. This approach also would require additional effort to isolate these custom modules if the server is multi-tenant; likely requiring an isolated run environment be created during runtime which can get complicated. We would also want a cleanup process to run after validation to free up the space consumed by loading these modules. - -## Notes On Data Auditing -Outside of our typical telemetry (metrics, logs, and traces) we also want to capture audit-style details about both the Guard objects (railspec configurations) and the data consumed and generated during validation. - +You can create this database however you wish, but we do have a docker-compose configuration to stand up the database as well as a local opentelemetry stack. +To use this make sure you have docker installed, then run: -### Auditing The Guards/Railspecs -The first is relatively simple; we can use the concepts of Functions and Triggers already built in to Postgres to capture a snapshot of the Guard object when writes occur. These snapshots would be stored in their own table `guards_audit` and have all of the same columns as guards with the addition of a `replaced_on` timestamp column and `replaced_by` user id column to capture when it was updated and by whom. +`docker compose --profile db up --build` +to run just the database. -The most straight forward way to allow users to query previous versions of these objects are to implement a `GET /guards/{guard-name}` endpoint with an optional `as-of` datetime query parameter. This limits the results to only one guard and one version which should satisfy most auditing use cases. For a more flexible search, we could add `start-date` and `end-date` datetime query paramters to this same endpoint. +`docker compose --profile infra up --build` +to run the database and opentelemetry infrastructure -### Auditing Validations -Auditing the data used during validation is slightly less simple because it requires additional server-side implementation. -Depending on what all we need to capture, we might be able to just persist the history object already generated by the sdk to a new postgres table. In addtion to the json data the history object contains, we would also want to tag it with a start time, end time, whether the validations succeed or failed, the guard name, and if possible the trace-id for the request. Starting out we can probaly just save this to a dedicate postgres table via SqlAlchemy before returning the results to the user. In the long wrong though, since this data can potentially get rather large, we would want to send it to an asynchronous agent to handle while we return the result to the user. +or +`docker compose --profile all up --build` +to run everything including the guardrails-api -This data has the potential to grow to a large size, both per row and number of rows. The row, or really column, size is handled automatically if the data is stored in Postgres since it compresses large items automatically. The number of rows however is more worrisome. In the short term proper indexing should suffice. In the long term it might make sense to move this data to something more wide-column/key-value oriented like DynamoDB or OpenSearch. -Exposing this data to a consumer could be accomplished via an endpoint like `GET /guards/audit/{guard-name}`. Potential user specified query params are the time range (start and end datetimes) or the trace-id but not both since the trace implicitly specifies a time range/single document. This endpoint would be a good candidate for GraphQL especially if the user wanted to query these audit records along with other telemetry. \ No newline at end of file +The last option is useful when checking that everything will work as planned in a more productionized environment. When developing, it's generally faster to just run the minimum infrastructure you need via docker and run the api on a bare process with the `make start` command. \ No newline at end of file diff --git a/app.py b/app.py index aad8666..e863d84 100644 --- a/app.py +++ b/app.py @@ -4,7 +4,9 @@ from werkzeug.middleware.proxy_fix import ProxyFix from urllib.parse import urlparse from guardrails import configure_logging -# from opentelemetry.instrumentation.flask import FlaskInstrumentor +from opentelemetry.instrumentation.flask import FlaskInstrumentor +from src.clients.postgres_client import postgres_is_enabled +from src.otel import otel_is_disabled, initialize class ReverseProxied(object): @@ -19,6 +21,11 @@ def __call__(self, environ, start_response): def create_app(): + if os.environ.get("APP_ENVIRONMENT") != "production": + from dotenv import load_dotenv + + load_dotenv() + app = Flask(__name__) app.config["APPLICATION_ROOT"] = "/" @@ -31,12 +38,16 @@ def create_app(): guardrails_log_level = os.environ.get("GUARDRAILS_LOG_LEVEL", "INFO") configure_logging(log_level=guardrails_log_level) - # FlaskInstrumentor().instrument_app(app) + if not otel_is_disabled(): + FlaskInstrumentor().instrument_app(app) + initialize() - from src.clients.postgres_client import PostgresClient + # if no pg_host is set, don't set up postgres + if postgres_is_enabled(): + from src.clients.postgres_client import PostgresClient - pg_client = PostgresClient() - pg_client.initialize(app) + pg_client = PostgresClient() + pg_client.initialize(app) from src.blueprints.root import root_bp from src.blueprints.guards import guards_bp diff --git a/compose.yml b/compose.yml new file mode 100644 index 0000000..712810b --- /dev/null +++ b/compose.yml @@ -0,0 +1,143 @@ +services: + postgres: + profiles: ["all", "db", "infra"] + image: ankane/pgvector + environment: + POSTGRES_USER: ${PGUSER:-postgres} + POSTGRES_PASSWORD: ${PGPASSWORD:-changeme} + POSTGRES_DATA: /data/postgres + volumes: + - ./postgres:/data/postgres + ports: + - "5432:5432" + restart: always + pgadmin: + profiles: ["all", "db", "infra"] + image: dpage/pgadmin4 + logging: + driver: none + restart: always + ports: + - "8088:80" + environment: + PGADMIN_DEFAULT_EMAIL: "${PGUSER:-postgres}@guardrails.com" + PGADMIN_DEFAULT_PASSWORD: ${PGPASSWORD:-changeme} + PGADMIN_SERVER_JSON_FILE: /var/lib/pgadmin/servers.json + volumes: + - ./pgadmin-data:/var/lib/pgadmin + depends_on: + - postgres + guardrails-api: + profiles: ["all", "api"] + image: guardrails-api:latest + build: + context: . + dockerfile: Dockerfile.dev + args: + PORT: "8000" + ports: + - "8000:8000" + environment: + APP_ENVIRONMENT: local + AWS_PROFILE: dev + AWS_DEFAULT_REGION: us-east-1 + PGPORT: 5432 + PGDATABASE: postgres + PGHOST: postgres + PGUSER: ${PGUSER:-postgres} + PGPASSWORD: ${PGPASSWORD:-changeme} + NLTK_DATA: /opt/nltk_data + OTEL_SERVICE_NAME: guardrails-api + OTEL_EXPORTER_OTLP_ENDPOINT: http://otel-collector:4317 + OTEL_TRACES_EXPORTER: otlp #,console + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: "Accept-Encoding,User-Agent,Referer" + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE: "Last-Modified,Content-Type" + OTEL_METRICS_EXPORTER: otlp #,console + # # Disable logging for now to reduce noise + # OTEL_LOGS_EXPORTER: otlp,console + # OTEL_PYTHON_LOG_CORRELATION: true + # OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED: true + # OTEL_PYTHON_LOG_LEVEL: INFO + PYTHONUNBUFFERED: 1 + LOGLEVEL: INFO + # # Use the below env vars if we ever split up sinks + # OTEL_EXPORTER_OTLP_TRACES_ENDPOINT: http://otel-collector:4317 + # OTEL_EXPORTER_OTLP_METRICS_ENDPOINT: http://otel-collector:4317 + # OTEL_EXPORTER_OTLP_LOGS_ENDPOINT: http://otel-collector:4317 + # OTEL_PYTHON_LOG_FORMAT: "%(msg)s [span_id=%(span_id)s]" + depends_on: + - postgres + - otel-collector + opensearch-node1: + profiles: ["all", "otel", "infra"] + image: opensearchproject/opensearch:latest + container_name: opensearch-node1 + environment: + - cluster.name=opensearch-cluster # Name the cluster + - node.name=opensearch-node1 # Name the node that will run in this container + - discovery.type=single-node + # - discovery.seed_hosts=opensearch-node1 # Nodes to look for when discovering the cluster + # - cluster.initial_cluster_manager_nodes=opensearch-node1 # Nodes eligibile to serve as cluster manager + - bootstrap.memory_lock=true # Disable JVM heap memory swapping + - "OPENSEARCH_JAVA_OPTS=-Xms512m -Xmx512m" # Set min and max JVM heap sizes to at least 50% of system RAM + - DISABLE_INSTALL_DEMO_CONFIG=true # Prevents execution of bundled demo script which installs demo certificates and security configurations to OpenSearch + - DISABLE_SECURITY_PLUGIN=true # Disables Security plugin + ulimits: + memlock: + soft: -1 # Set memlock to unlimited (no soft or hard limit) + hard: -1 + nofile: + soft: 65536 # Maximum number of open files for the opensearch user - set to at least 65536 + hard: 65536 + volumes: + - ./opensearch/opensearch-data1:/usr/share/opensearch/data # Creates volume called opensearch-data1 and mounts it to the container + ports: + - 9200:9200 # REST API + - 9600:9600 # Performance Analyzer + opensearch-dashboards: + profiles: ["all", "otel", "infra"] + image: opensearchproject/opensearch-dashboards:latest + container_name: opensearch-dashboards + ports: + - 5601:5601 # Map host port 5601 to container port 5601 + expose: + - "5601" # Expose port 5601 for web access to OpenSearch Dashboards + environment: + - OPENSEARCH_HOSTS=["http://opensearch-node1:9200"] + - DISABLE_SECURITY_DASHBOARDS_PLUGIN=true # disables security dashboards plugin in OpenSearch Dashboards + data-prepper: + profiles: ["all", "otel", "infra"] + restart: unless-stopped + container_name: data-prepper + image: opensearchproject/data-prepper:latest + volumes: + - ./configs/pipelines.yml:/usr/share/data-prepper/pipelines/pipelines.yaml + - ./configs/data-prepper-config.yml:/usr/share/data-prepper/config/data-prepper-config.yaml + ports: + - 21890:21890 + - 21891:21891 + - 21892:21892 + expose: + - "21890" + - "21891" + - "21892" + depends_on: + - opensearch-node1 + otel-collector: + profiles: ["all", "otel", "infra"] + restart: unless-stopped + container_name: otel-collector + image: otel/opentelemetry-collector:latest + command: ["--config=/etc/otel-collector-config.yml"] + volumes: + - ./configs/otel-collector-config.yml:/etc/otel-collector-config.yml + ports: + - 1888:1888 # pprof extension + # - 8888:8888 # Prometheus metrics exposed by the collector + # - 8889:8889 # Prometheus exporter metrics + - 13133:13133 # health_check extension + - 4317:4317 # OTLP gRPC receiver + # - 4318:4318 # OTLP http receiver + # - 55679:55679 # zpages extension + depends_on: + - data-prepper \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..5f67b8e --- /dev/null +++ b/config.py @@ -0,0 +1,11 @@ +''' +All guards defined here will be initialized, if and only if +the application is using in memory guards. + +The application will use in memory guards if pg_host is left +undefined. Otherwise, a postgres instance will be started +and guards will be persisted into postgres. In that case, +these guards will not be initialized. +''' + +from guardrails import Guard \ No newline at end of file diff --git a/custom-install/install.py b/custom-install/install.py new file mode 100644 index 0000000..85ea608 --- /dev/null +++ b/custom-install/install.py @@ -0,0 +1,113 @@ +import os +import sys +import logging +import json +from typing import Any, Dict +from rich.console import Console +from guardrails.cli.hub.install import ( + get_site_packages_location, + install_hub_module, + run_post_install, + add_to_hub_inits +) +from guardrails.cli.server.module_manifest import ModuleManifest +from string import Template + +console = Console() + +os.environ[ + "COLOREDLOGS_LEVEL_STYLES" +] = "spam=white,faint;success=green,bold;debug=magenta;verbose=blue;notice=cyan,bold;warning=yellow;error=red;critical=background=red" # noqa +LEVELS = { + "SPAM": 5, + "VERBOSE": 15, + "NOTICE": 25, + "SUCCESS": 35, +} +for key in LEVELS: + logging.addLevelName(LEVELS.get(key), key) # type: ignore +logger = logging.getLogger("custom-install") + + +def load_manifest(fileName: str) -> Dict[str, Any]: + with open(f"custom-install/manifests/{fileName}") as manifest_file: + content = manifest_file.read() + return json.loads(content) + +custom_manifests = { + "guardrails/provenance_llm": load_manifest("provenance-llm.json"), + "guardrails/detect_pii": load_manifest("detect-pii.json"), + "guardrails/competitor_check": load_manifest("competitor-check.json"), + "guardrails/many_shot_jailbreak": load_manifest("jailbreak.json"), + "tryolabs/restricttotopic": load_manifest("restrict-to-topic.json"), +} + +def get_validator_manifest(module_name) -> ModuleManifest: + manifest = custom_manifests.get(module_name, {}) + return ModuleManifest.from_dict(manifest) + +def custom_install(package_uri: str): + """Install a validator from the Hub.""" + if not package_uri.startswith("hub://"): + logger.error("Invalid URI!") + sys.exit(1) + + console.print(f"\nInstalling {package_uri}...\n") + logger.log( + level=LEVELS.get("SPAM"), msg=f"Installing {package_uri}..." # type: ignore + ) + + # Validation + module_name = package_uri.replace("hub://", "") + + # Prep + with console.status("Fetching manifest", spinner="bouncingBar"): + module_manifest = get_validator_manifest(module_name) + site_packages = get_site_packages_location() + + # Install + with console.status("Downloading dependencies", spinner="bouncingBar"): + install_hub_module(module_manifest, site_packages) + + # Post-install + with console.status("Running post-install setup", spinner="bouncingBar"): + run_post_install(module_manifest, site_packages) + add_to_hub_inits(module_manifest, site_packages) + + success_message_cli = Template( + """✅Successfully installed ${module_name}! + +[bold]Import validator:[/bold] +from guardrails.hub import ${export} + +[bold]Get more info:[/bold] +https://hub.guardrailsai.com/validator/${id} +""" + ).safe_substitute( + module_name=package_uri, + id=module_manifest.id, + export=module_manifest.exports[0], + ) + success_message_logger = Template( + """✅Successfully installed ${module_name}! + +Import validator: +from guardrails.hub import ${export} + +Get more info: +https://hub.guardrailsai.com/validator/${id} +""" + ).safe_substitute( + module_name=package_uri, + id=module_manifest.id, + export=module_manifest.exports[0], + ) + console.print(success_message_cli) # type: ignore + logger.log(level=LEVELS.get("SPAM"), msg=success_message_logger) # type: ignore + + +# custom_install("hub://guardrails/provenance_llm") +custom_install("hub://tryolabs/restricttotopic") +# custom_install("hub://guardrails/detect_pii") +# custom_install("hub://guardrails/competitor_check") +# custom_install("hub://guardrails/many_shot_jailbreak") diff --git a/custom-install/manifests/competitor-check.json b/custom-install/manifests/competitor-check.json new file mode 100644 index 0000000..b936fa2 --- /dev/null +++ b/custom-install/manifests/competitor-check.json @@ -0,0 +1,46 @@ +{ + "name": "Competitor Check", + "author": { + "name": "Guardrails AI", + "email": "contact@guardrailsai.com" + }, + "maintainers": [{ + "name": "Karan Acharya", + "email": "karan@guardrailsai.com" + }], + "repository": { + "url": "https://github.com/guardrails-ai/competitor_check.git", + "branch": "frontend_demo" + }, + "index": "./__init__.py", + "exports": [ + "CompetitorCheck" + ], + "tags": { + "language": [ + "en" + ], + "certification": [ + "Guardrails Certified" + ], + "contentType": [ + "string" + ], + "infrastructureRequirements": [ + "ML" + ], + "riskCategory": [ + "Brand risk" + ], + "useCases": [ + "Chatbots", + "Customer Support" + ] + }, + "id": "guardrails/competitor_check", + "namespace": "guardrails", + "packageName": "competitor_check", + "moduleName": "validator", + "requiresAuth": false, + "postInstall": "post-install.py" +} \ No newline at end of file diff --git a/custom-install/manifests/detect-pii.json b/custom-install/manifests/detect-pii.json new file mode 100644 index 0000000..c0c37ad --- /dev/null +++ b/custom-install/manifests/detect-pii.json @@ -0,0 +1,49 @@ +{ + "name": "Detect PII", + "author": { + "name": "Guardrails AI", + "email": "contact@guardrailsai.com" + }, + "maintainers": [{ + "name": "Caleb Courier", + "email": "caleb@guardrailsai.com" + }], + "repository": { + "url": "https://github.com/guardrails-ai/detect_pii.git", + "branch": "frontend_demo" + }, + "index": "./__init__.py", + "exports": [ + "DetectPII" + ], + "tags": { + "language": [ + "en" + ], + "certification": [ + "Guardrails Certified" + ], + "contentType": [ + "string" + ], + "infrastructureRequirements": [ + "ML" + ], + "riskCategory": [ + "Data Leakage" + ], + "useCases": [ + "Chatbots", + "RAG", + "CodeGen", + "Structured data", + "Customer Support" + ] + }, + "id": "guardrails/detect_pii", + "namespace": "guardrails", + "packageName": "detect_pii", + "moduleName": "validator", + "requiresAuth": false, + "postInstall": "post-install.py" +} \ No newline at end of file diff --git a/custom-install/manifests/jailbreak.json b/custom-install/manifests/jailbreak.json new file mode 100644 index 0000000..f461dc0 --- /dev/null +++ b/custom-install/manifests/jailbreak.json @@ -0,0 +1,49 @@ +{ + "name": "Detect Many Shot Jailbreak", + "author": { + "name": "Guardrails AI", + "email": "contact@guardrailsai.com" + }, + "maintainers": [{ + "name": "Wyatt Lansford", + "email": "wyatt@guardrailsai.com" + }], + "repository": { + "url": "https://github.com/guardrails-ai/jailbreak.git", + "branch": "main" + }, + "index": "./__init__.py", + "exports": [ + "DetectManyShotJailbreak" + ], + "tags": { + "language": [ + "en" + ], + "certification": [ + "Guardrails Certified" + ], + "contentType": [ + "string" + ], + "infrastructureRequirements": [ + "ML" + ], + "riskCategory": [ + "Data Leakage" + ], + "useCases": [ + "Chatbots", + "RAG", + "CodeGen", + "Structured data", + "Customer Support" + ] + }, + "id": "guardrails/many_shot_jailbreak", + "namespace": "guardrails", + "packageName": "many_shot_jailbreak", + "moduleName": "validator", + "requiresAuth": false, + "postInstall": "post-install.py" +} \ No newline at end of file diff --git a/custom-install/manifests/provenance-llm.json b/custom-install/manifests/provenance-llm.json new file mode 100644 index 0000000..aa9f950 --- /dev/null +++ b/custom-install/manifests/provenance-llm.json @@ -0,0 +1,49 @@ +{ + "name": "Provenance LLM", + "author": { + "name": "Guardrails AI", + "email": "contact@guardrailsai.com" + }, + "maintainers": [{ + "name": "Caleb Courier", + "email": "caleb@guardrailsai.com" + }], + "repository": { + "url": "https://github.com/guardrails-ai/provenance_llm.git", + "branch": "default-embed-func" + }, + "index": "./__init__.py", + "exports": [ + "ProvenanceLLM" + ], + "tags": { + "language": [ + "en" + ], + "certification": [ + "Guardrails Certified" + ], + "contentType": [ + "string" + ], + "infrastructureRequirements": [ + "ML", + "LLM" + ], + "riskCategory": [ + "Factuality", + "Brand risk" + ], + "useCases": [ + "Chatbots", + "RAG", + "Customer Support" + ] + }, + "id": "guardrails/provenance_llm", + "namespace": "guardrails", + "packageName": "provenance_llm", + "moduleName": "validator", + "requiresAuth": false, + "postInstall": "post-install.py" +} \ No newline at end of file diff --git a/custom-install/manifests/restrict-to-topic.json b/custom-install/manifests/restrict-to-topic.json new file mode 100644 index 0000000..3f55972 --- /dev/null +++ b/custom-install/manifests/restrict-to-topic.json @@ -0,0 +1,49 @@ +{ + "name": "Restrict to Topic", + "author": { + "name": "Tryolabs", + "email": "hello@tryolabs.com" + }, + "maintainers": [{ + "name": "Paz", + "email": "paz@tyrolabs.com" + }], + "repository": { + "url": "https://github.com/guardrails-ai/restricttotopic.git", + "branch": "streaming_demo" + }, + "index": "./__init__.py", + "exports": [ + "RestrictToTopic" + ], + "tags": { + "language": [ + "en" + ], + "certification": [ + "Guardrails Certified" + ], + "contentType": [ + "string" + ], + "infrastructureRequirements": [ + "LLM", + "ML" + ], + "riskCategory": [ + "Etiquette", + "Jailbreaking", + "Brand risk" + ], + "useCases": [ + "Chatbots", + "Customer Support" + ] + }, + "id": "tryolabs/restricttotopic", + "namespace": "tryolabs", + "packageName": "restricttotopic", + "moduleName": "validator", + "requiresAuth": false, + "postInstall": "post-install.py" +} \ No newline at end of file diff --git a/dev-build.sh b/dev-build.sh index 7bd9aae..fd4483e 100644 --- a/dev-build.sh +++ b/dev-build.sh @@ -1,3 +1,5 @@ +curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml + npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml docker build \ @@ -7,4 +9,4 @@ docker build \ --build-arg CACHEBUST="$(date)" \ --build-arg GITHUB_TOKEN="$GITHUB_TOKEN" \ --build-arg HF_TOKEN="$HF_TOKEN" \ - -t "guardrails-api:dev" .; \ No newline at end of file + -t "guardrails-api:dev" .; diff --git a/hub-requirements.txt b/hub-requirements.txt new file mode 100644 index 0000000..c117ad4 --- /dev/null +++ b/hub-requirements.txt @@ -0,0 +1,6 @@ +hub://guardrails/ends_with +hub://guardrails/extracted_summary_sentences_match +hub://guardrails/llm_critic +hub://guardrails/provenance_embeddings +hub://guardrails/valid_length +hub://guardrails/regex_match \ No newline at end of file diff --git a/lite-build.sh b/lite-build.sh new file mode 100755 index 0000000..9fb9508 --- /dev/null +++ b/lite-build.sh @@ -0,0 +1,12 @@ +curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml + +npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml + +docker build \ + -f Dockerfile.lite \ + --progress=plain \ + --no-cache \ + --build-arg CACHEBUST="$(date)" \ + --build-arg GITHUB_TOKEN="$GITHUB_TOKEN" \ + --build-arg HF_TOKEN="$HF_TOKEN" \ + -t "guardrails-api:lite" .; diff --git a/lite-run.sh b/lite-run.sh new file mode 100755 index 0000000..a6086a3 --- /dev/null +++ b/lite-run.sh @@ -0,0 +1,3 @@ +docker stop guardrails-api-lite || true +docker rm guardrails-api-lite || true +docker run -p 8000:8000 --env-file local.env --name guardrails-api-lite -it guardrails-api:lite \ No newline at end of file diff --git a/local.sh b/local.sh old mode 100644 new mode 100755 index c6af85f..f8a671a --- a/local.sh +++ b/local.sh @@ -13,23 +13,28 @@ export PGPASSWORD=${PGPASSWORD:-changeme} export PYTHONUNBUFFERED=1 export OTEL_PYTHON_TRACER_PROVIDER=sdk_tracer_provider export OTEL_SERVICE_NAME=guardrails-api -export OTEL_TRACES_EXPORTER=none # otlp #,console +export OTEL_TRACES_EXPORTER=otlp #,console export OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST="Accept-Encoding,User-Agent,Referer" export OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE="Last-Modified,Content-Type" export OTEL_METRICS_EXPORTER=none #otlp #,console -export OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf +# export OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf # export OTEL_EXPORTER_OTLP_ENDPOINT=https://hty0gc1ok3.execute-api.us-east-1.amazonaws.com +export OTEL_EXPORTER_OTLP_PROTOCOL=grpc export OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4317 +export OTEL_SDK_DISABLED=true + # export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=https://hty0gc1ok3.execute-api.us-east-1.amazonaws.com/v1/traces # export OTEL_EXPORTER_OTLP_METRICS_ENDPOINT=https://hty0gc1ok3.execute-api.us-east-1.amazonaws.com/v1/metrics # export OTEL_EXPORTER_OTLP_LOGS_ENDPOINT=https://hty0gc1ok3.execute-api.us-east-1.amazonaws.com/v1/logs -export LOGLEVEL=DEBUG +export LOGLEVEL="INFO" export GUARDRAILS_LOG_LEVEL="INFO" export GUARDRAILS_PROCESS_COUNT=1 export SELF_ENDPOINT=http://localhost:8000 export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES +export HF_API_KEY=${HF_TOKEN} + curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml @@ -42,7 +47,6 @@ npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json -# opentelemetry-instrument gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "app:create_app()" # For running https locally # gunicorn --keyfile ~/certificates/local.key --certfile ~/certificates/local.cert --bind 0.0.0.0:8000 --timeout=5 --threads=10 "app:create_app()" -opentelemetry-instrument gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "app:create_app()" \ No newline at end of file +gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "app:create_app()" diff --git a/prod-build.sh b/prod-build.sh index 3dc1fc6..8ed2000 100644 --- a/prod-build.sh +++ b/prod-build.sh @@ -1,3 +1,5 @@ +curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml + # Dereference API Spec to JSON npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml @@ -13,4 +15,4 @@ docker build \ --progress=plain \ --build-arg CACHEBUST="$(date)" \ --no-cache \ - -t "guardrails-api:prod" .; \ No newline at end of file + -t "guardrails-api:prod" .; diff --git a/requirements-lock.txt b/requirements-lock.txt index 1928e01..5171469 100644 --- a/requirements-lock.txt +++ b/requirements-lock.txt @@ -1,89 +1,122 @@ -annotated-types==0.6.0 -anyio==4.3.0 +aiohttp==3.9.5 +aiosignal==1.3.1 +annotated-types==0.7.0 +anyio==4.4.0 +asgiref==3.8.1 attrs==23.2.0 -backoff==2.2.1 -blinker==1.7.0 -boto3==1.34.66 -botocore==1.34.66 +blinker==1.8.2 +boto3==1.34.115 +botocore==1.34.115 certifi==2024.2.2 cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 colorama==0.4.6 coloredlogs==15.0.1 -cryptography==42.0.5 +cryptography==42.0.7 Deprecated==1.2.14 distro==1.9.0 faiss-cpu==1.8.0 -Flask==3.0.2 -Flask-Cors==4.0.0 +filelock==3.14.0 +Flask==3.0.3 +Flask-Cors==4.0.1 Flask-SQLAlchemy==3.1.1 +frozenlist==1.4.1 +fsspec==2024.5.0 googleapis-common-protos==1.63.0 griffe==0.36.9 -grpcio==1.62.1 -guardrails-ai==0.4.2 -gunicorn==21.2.0 +grpcio==1.64.0 +guardrails-ai @ git+https://github.com/guardrails-ai/guardrails.git@fd77007dfe823f8cb32cd314b78e5f63aea71e9a +gunicorn==22.0.0 h11==0.14.0 -httpcore==1.0.4 +httpcore==1.0.5 httpx==0.27.0 +huggingface-hub==0.23.2 humanfriendly==10.0 -idna==3.6 -importlib-metadata==6.11.0 -itsdangerous==2.1.2 -Jinja2==3.1.3 +idna==3.7 +importlib-metadata==7.0.0 +itsdangerous==2.2.0 +Jinja2==3.1.4 jmespath==1.0.1 -joblib==1.3.2 +joblib==1.4.2 jsonpatch==1.33 jsonpointer==2.4 -jsonschema==4.21.1 +jsonschema==4.22.0 jsonschema-specifications==2023.12.1 jwt==1.3.1 -langchain-core==0.1.32 -langsmith==0.1.31 +langchain-core==0.1.52 +langsmith==0.1.65 +litellm==1.39.3 lxml==4.9.4 markdown-it-py==3.0.0 MarkupSafe==2.1.5 mdurl==0.1.2 +multidict==6.0.5 nltk==3.8.1 numpy==1.26.4 -openai==1.14.2 -opentelemetry-api==1.20.0 -opentelemetry-distro==0.41b0 -opentelemetry-exporter-otlp-proto-common==1.20.0 -opentelemetry-exporter-otlp-proto-grpc==1.20.0 -opentelemetry-exporter-otlp-proto-http==1.20.0 -opentelemetry-instrumentation==0.41b0 -opentelemetry-proto==1.20.0 -opentelemetry-sdk==1.20.0 -opentelemetry-semantic-conventions==0.41b0 -orjson==3.9.15 +openai==1.30.5 +opentelemetry-api==1.24.0 +opentelemetry-distro==0.45b0 +opentelemetry-exporter-otlp-proto-common==1.24.0 +opentelemetry-exporter-otlp-proto-grpc==1.24.0 +opentelemetry-exporter-otlp-proto-http==1.24.0 +opentelemetry-instrumentation==0.45b0 +opentelemetry-instrumentation-asgi==0.45b0 +opentelemetry-instrumentation-asyncio==0.45b0 +opentelemetry-instrumentation-aws-lambda==0.45b0 +opentelemetry-instrumentation-boto3sqs==0.45b0 +opentelemetry-instrumentation-botocore==0.45b0 +opentelemetry-instrumentation-dbapi==0.45b0 +opentelemetry-instrumentation-flask==0.45b0 +opentelemetry-instrumentation-grpc==0.45b0 +opentelemetry-instrumentation-httpx==0.45b0 +opentelemetry-instrumentation-jinja2==0.45b0 +opentelemetry-instrumentation-logging==0.45b0 +opentelemetry-instrumentation-requests==0.45b0 +opentelemetry-instrumentation-sqlalchemy==0.45b0 +opentelemetry-instrumentation-sqlite3==0.45b0 +opentelemetry-instrumentation-tortoiseorm==0.45b0 +opentelemetry-instrumentation-urllib==0.45b0 +opentelemetry-instrumentation-urllib3==0.45b0 +opentelemetry-instrumentation-wsgi==0.45b0 +opentelemetry-propagator-aws-xray==1.0.1 +opentelemetry-proto==1.24.0 +opentelemetry-sdk==1.24.0 +opentelemetry-semantic-conventions==0.45b0 +opentelemetry-test-utils==0.45b0 +opentelemetry-util-http==0.45b0 +orjson==3.10.3 packaging==23.2 protobuf==4.25.3 psycopg2-binary==2.9.9 -pycparser==2.21 -pydantic==2.6.4 -pydantic_core==2.16.3 +pycparser==2.22 +pydantic==2.7.2 +pydantic_core==2.18.3 pydash==7.0.7 -Pygments==2.17.2 +Pygments==2.18.0 python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 PyYAML==6.0.1 -referencing==0.34.0 +referencing==0.35.1 regex==2023.12.25 -requests==2.31.0 +requests==2.32.3 rich==13.7.1 -rpds-py==0.18.0 +rpds-py==0.18.1 rstr==3.2.2 s3transfer==0.10.1 +setuptools==70.0.0 shellingham==1.5.4 six==1.16.0 sniffio==1.3.1 -SQLAlchemy==2.0.28 -tenacity==8.2.3 -tiktoken==0.5.2 -tqdm==4.66.2 -typer==0.9.0 -typing_extensions==4.10.0 +SQLAlchemy==2.0.30 +tenacity==8.3.0 +tiktoken==0.7.0 +tokenizers==0.19.1 +tqdm==4.66.4 +typer==0.9.4 +typing_extensions==4.12.0 urllib3==2.2.1 -Werkzeug==3.0.1 +Werkzeug==3.0.3 wrapt==1.16.0 -zipp==3.18.1 +yarl==1.9.4 +zipp==3.19.0 diff --git a/requirements.txt b/requirements.txt index 1a94275..3eb72dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,11 @@ flask sqlalchemy lxml -guardrails-ai +guardrails-ai @ git+https://github.com/guardrails-ai/guardrails.git@async_streaming # Let this come from guardrails-ai as a transient dependency. # Pip confuses tag versions with commit ids, # and claims a conflict even though it's the same thing. # guard-rails-api-client @ git+https://github.com/guardrails-ai/guardrails-api-client.git@v0.0.2#egg=guard-rails-api-client&subdirectory=guard-rails-api-client -opentelemetry-distro==0.41b0 -opentelemetry-sdk==1.20.0 -opentelemetry-exporter-otlp-proto-grpc==1.20.0 -opentelemetry-exporter-otlp-proto-http==1.20.0 flask_sqlalchemy werkzeug jsonschema @@ -20,4 +16,5 @@ faiss-cpu nltk boto3 gunicorn -psycopg2-binary \ No newline at end of file +psycopg2-binary +litellm \ No newline at end of file diff --git a/sample-config.py b/sample-config.py new file mode 100644 index 0000000..5667e9f --- /dev/null +++ b/sample-config.py @@ -0,0 +1,55 @@ +''' +All guards defined here will be initialized, if and only if +the application is using in memory guards. + +The application will use in memory guards if pg_host is left +undefined. Otherwise, a postgres instance will be started +and guards will be persisted into postgres. In that case, +these guards will not be initialized. +''' + +from guardrails import Guard +from guardrails.hub import RegexMatch, RestrictToTopic + +name_case = Guard( + name='name-case', + description='Checks that a string is in Name Case format.' +).use( + RegexMatch(regex="^[A-Z][a-z\\s]*$") +) + +all_caps = Guard( + name='all-caps', + description='Checks that a string is all capital.' +).use( + RegexMatch(regex="^[A-Z\\s]*$") +) + +valid_topics = ["music", "cooking", "camping", "outdoors"] +invalid_topics = ["sports", "work", "ai"] +all_topics = [*valid_topics, *invalid_topics] + +def custom_llm (text: str, *args, **kwargs): + return [ + { + "name": t, + "present": (t in text), + "confidence": 5 + } + for t in all_topics + ] + +custom_code_guard = Guard( + name='custom', + description='Uses a custom llm for RestrictToTopic' +).use( + RestrictToTopic( + valid_topics=valid_topics, + invalid_topics=invalid_topics, + llm_callable=custom_llm, + disable_classifier=True, + disable_llm=False, + # Pass this so it doesn't load the bart model + classifier_api_endpoint="https://m-1e7af27102f54c3a9eb9cb11aa4715bd-m.default.model-v2.inferless.com/v2/models/RestrictToTopic_1e7af27102f54c3a9eb9cb11aa4715bd/versions/1/infer", + ) +) \ No newline at end of file diff --git a/src/blueprints/guards.py b/src/blueprints/guards.py index fef6bb1..48af3d8 100644 --- a/src/blueprints/guards.py +++ b/src/blueprints/guards.py @@ -1,36 +1,63 @@ -import os import json +import os +from guardrails.hub import * # noqa from string import Template -from flask import Blueprint, request +from typing import Any, Dict, cast +from flask import Blueprint, Response, request, stream_with_context from urllib.parse import unquote_plus from guardrails import Guard from guardrails.classes import ValidationOutcome -from opentelemetry.trace import get_tracer +from opentelemetry.trace import Span from src.classes.guard_struct import GuardStruct from src.classes.http_error import HttpError from src.classes.validation_output import ValidationOutput -from src.clients.guard_client import GuardClient +from src.clients.memory_guard_client import MemoryGuardClient +from src.clients.pg_guard_client import PGGuardClient +from src.clients.postgres_client import postgres_is_enabled from src.utils.handle_error import handle_error -from src.utils.gather_request_metrics import gather_request_metrics from src.utils.get_llm_callable import get_llm_callable from src.utils.prep_environment import cleanup_environment, prep_environment guards_bp = Blueprint("guards", __name__, url_prefix="/guards") -guard_client = GuardClient() + + +# if no pg_host is set, use in memory guards +if postgres_is_enabled(): + guard_client = PGGuardClient() +else: + guard_client = MemoryGuardClient() + # read in guards from file + import config + + exports = config.__dir__() + for export_name in exports: + export = getattr(config, export_name) + is_guard = isinstance(export, Guard) + if is_guard: + guard_client.create_guard(export) @guards_bp.route("/", methods=["GET", "POST"]) @handle_error -@gather_request_metrics def guards(): if request.method == "GET": guards = guard_client.get_guards() + if len(guards) > 0 and (isinstance(guards[0], Guard)): + return [g._to_request() for g in guards] return [g.to_response() for g in guards] elif request.method == "POST": + if not postgres_is_enabled(): + raise HttpError( + 501, + "NotImplemented", + "POST /guards is not implemented for in-memory guards.", + ) payload = request.json guard = GuardStruct.from_request(payload) new_guard = guard_client.create_guard(guard) + if isinstance(new_guard, Guard): + return new_guard._to_request() return new_guard.to_response() else: raise HttpError( @@ -43,20 +70,45 @@ def guards(): @guards_bp.route("/", methods=["GET", "PUT", "DELETE"]) @handle_error -@gather_request_metrics def guard(guard_name: str): decoded_guard_name = unquote_plus(guard_name) if request.method == "GET": as_of_query = request.args.get("asOf") guard = guard_client.get_guard(decoded_guard_name, as_of_query) + if guard is None: + raise HttpError( + 404, + "NotFound", + "A Guard with the name {guard_name} does not exist!".format( + guard_name=decoded_guard_name + ), + ) + if isinstance(guard, Guard): + return guard._to_request() return guard.to_response() elif request.method == "PUT": + if not postgres_is_enabled(): + raise HttpError( + 501, + "NotImplemented", + "PUT / is not implemented for in-memory guards.", + ) payload = request.json guard = GuardStruct.from_request(payload) updated_guard = guard_client.upsert_guard(decoded_guard_name, guard) + if isinstance(updated_guard, Guard): + return updated_guard._to_request() return updated_guard.to_response() elif request.method == "DELETE": + if not postgres_is_enabled(): + raise HttpError( + 501, + "NotImplemented", + "DELETE / is not implemented for in-memory guards.", + ) guard = guard_client.delete_guard(decoded_guard_name) + if isinstance(guard, Guard): + return guard._to_request() return guard.to_response() else: raise HttpError( @@ -67,9 +119,49 @@ def guard(guard_name: str): ) +def collect_telemetry( + *, + guard: Guard, + validate_span: Span, + validation_output: ValidationOutput, + prompt_params: Dict[str, Any], + result: ValidationOutcome, +): + # Below is all telemetry collection and + # should have no impact on what is returned to the user + prompt = guard.history.last.inputs.prompt + if prompt: + prompt = Template(prompt).safe_substitute(**prompt_params) + validate_span.set_attribute("prompt", prompt) + + instructions = guard.history.last.inputs.instructions + if instructions: + instructions = Template(instructions).safe_substitute(**prompt_params) + validate_span.set_attribute("instructions", instructions) + + validate_span.set_attribute("validation_status", guard.history.last.status) + validate_span.set_attribute("raw_llm_ouput", result.raw_llm_output) + + # Use the serialization from the class instead of re-writing it + valid_output: str = ( + json.dumps(validation_output.validated_output) + if isinstance(validation_output.validated_output, dict) + else str(validation_output.validated_output) + ) + validate_span.set_attribute("validated_output", valid_output) + + validate_span.set_attribute("tokens_consumed", guard.history.last.tokens_consumed) + + num_of_reasks = ( + guard.history.last.iterations.length - 1 + if guard.history.last.iterations.length > 0 + else 0 + ) + validate_span.set_attribute("num_of_reasks", num_of_reasks) + + @guards_bp.route("//validate", methods=["POST"]) @handle_error -@gather_request_metrics def validate(guard_name: str): # Do we actually need a child span here? # We could probably use the existing span from the request unless we forsee @@ -82,107 +174,158 @@ def validate(guard_name: str): " {request_method}".format(request_method=request.method), ) payload = request.json - openai_api_key = request.headers.get("x-openai-api-key", None) + openai_api_key = request.headers.get( + "x-openai-api-key", os.environ.get("OPENAI_API_KEY") + ) decoded_guard_name = unquote_plus(guard_name) guard_struct = guard_client.get_guard(decoded_guard_name) - prep_environment(guard_struct) + if isinstance(guard_struct, GuardStruct): + # TODO: is there a way to do this with Guard? + prep_environment(guard_struct) llm_output = payload.pop("llmOutput", None) num_reasks = payload.pop("numReasks", guard_struct.num_reasks) prompt_params = payload.pop("promptParams", {}) llm_api = payload.pop("llmApi", None) args = payload.pop("args", []) + stream = payload.pop("stream", False) - service_name = os.environ.get("OTEL_SERVICE_NAME", "guardrails-api") - otel_tracer = get_tracer(service_name) - - with otel_tracer.start_as_current_span( - f"validate-{decoded_guard_name}" - ) as validate_span: - guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer) - - validate_span.set_attribute("guardName", decoded_guard_name) - if llm_api is not None: - llm_api = get_llm_callable(llm_api) - if openai_api_key is None: - raise HttpError( - status=400, - message="BadRequest", - cause=( - "Cannot perform calls to OpenAI without an api key. Pass" - " openai_api_key when initializing the Guard or set the" - " OPENAI_API_KEY environment variable." - ), - ) - elif num_reasks > 1: + # service_name = os.environ.get("OTEL_SERVICE_NAME", "guardrails-api") + # otel_tracer = get_tracer(service_name) + + payload["api_key"] = payload.get("api_key", openai_api_key) + + # with otel_tracer.start_as_current_span( + # f"validate-{decoded_guard_name}" + # ) as validate_span: + # guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer) + guard: Guard = Guard() + if isinstance(guard_struct, GuardStruct): + guard: Guard = guard_struct.to_guard(openai_api_key) + elif isinstance(guard_struct, Guard): + guard = guard_struct + # validate_span.set_attribute("guardName", decoded_guard_name) + if llm_api is not None: + llm_api = get_llm_callable(llm_api) + if openai_api_key is None: raise HttpError( status=400, message="BadRequest", cause=( - "Cannot perform re-asks without an LLM API. Specify llm_api when" - " calling guard(...)." + "Cannot perform calls to OpenAI without an api key. Pass" + " openai_api_key when initializing the Guard or set the" + " OPENAI_API_KEY environment variable." ), ) + elif num_reasks and num_reasks > 1: + raise HttpError( + status=400, + message="BadRequest", + cause=( + "Cannot perform re-asks without an LLM API. Specify llm_api when" + " calling guard(...)." + ), + ) - if llm_output is not None: - result: ValidationOutcome = guard.parse( - llm_output=llm_output, - num_reasks=num_reasks, - prompt_params=prompt_params, - llm_api=llm_api, - api_key=openai_api_key, - *args, - **payload, - ) - else: - result: ValidationOutcome = guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - api_key=openai_api_key, - *args, - **payload, + if llm_output is not None: + if stream: + raise HttpError( + status=400, + message="BadRequest", + cause="Streaming is not supported for parse calls!", ) - # TODO: Just make this a ValidationOutcome with history - validation_output = ValidationOutput( - result.validation_passed, - result.validated_output, - guard.history, - result.raw_llm_output, + result: ValidationOutcome = guard.parse( + llm_output=llm_output, + num_reasks=num_reasks, + prompt_params=prompt_params, + llm_api=llm_api, + # api_key=openai_api_key, + *args, + **payload, ) + else: + if stream: - prompt = guard.history.last.inputs.prompt - if prompt: - prompt = Template(prompt).safe_substitute(**prompt_params) - validate_span.set_attribute("prompt", prompt) + def guard_streamer(): + guard_stream = guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + stream=stream, + # api_key=openai_api_key, + *args, + **payload, + ) - instructions = guard.history.last.inputs.instructions - if instructions: - instructions = Template(instructions).safe_substitute(**prompt_params) - validate_span.set_attribute("instructions", instructions) + for result in guard_stream: + # TODO: Just make this a ValidationOutcome with history + validation_output: ValidationOutput = ValidationOutput( + result.validation_passed, + result.validated_output, + guard.history, + result.raw_llm_output, + ) - validate_span.set_attribute("validation_status", guard.history.last.status) - validate_span.set_attribute("raw_llm_ouput", result.raw_llm_output) + yield validation_output, cast(ValidationOutcome, result) - # Use the serialization from the class instead of re-writing it - valid_output: str = ( - json.dumps(validation_output.validated_output) - if isinstance(validation_output.validated_output, dict) - else str(validation_output.validated_output) - ) - validate_span.set_attribute("validated_output", valid_output) + def validate_streamer(guard_iter): + next_result = None + # next_validation_output = None + for validation_output, result in guard_iter: + next_result = result + # next_validation_output = validation_output + fragment = json.dumps(validation_output.to_response()) + yield f"{fragment}\n" - validate_span.set_attribute( - "tokens_consumed", guard.history.last.tokens_consumed - ) + final_validation_output: ValidationOutput = ValidationOutput( + next_result.validation_passed, + next_result.validated_output, + guard.history, + next_result.raw_llm_output, + ) + # I don't know if these are actually making it to OpenSearch + # because the span may be ended already + # collect_telemetry( + # guard=guard, + # validate_span=validate_span, + # validation_output=next_validation_output, + # prompt_params=prompt_params, + # result=next_result + # ) + final_output_json = json.dumps(final_validation_output.to_response()) + yield f"{final_output_json}\n" + + return Response( + stream_with_context(validate_streamer(guard_streamer())), + content_type="application/json", + # content_type="text/event-stream" + ) - num_of_reasks = ( - guard.history.last.iterations.length - 1 - if guard.history.last.iterations.length > 0 - else 0 + result: ValidationOutcome = guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + # api_key=openai_api_key, + *args, + **payload, ) - validate_span.set_attribute("num_of_reasks", num_of_reasks) - cleanup_environment(guard_struct) + # TODO: Just make this a ValidationOutcome with history + validation_output = ValidationOutput( + result.validation_passed, + result.validated_output, + guard.history, + result.raw_llm_output, + ) + + # collect_telemetry( + # guard=guard, + # validate_span=validate_span, + # validation_output=validation_output, + # prompt_params=prompt_params, + # result=result + # ) + if isinstance(guard_struct, GuardStruct): + cleanup_environment(guard_struct) return validation_output.to_response() diff --git a/src/blueprints/root.py b/src/blueprints/root.py index 759a584..18a6cd9 100644 --- a/src/blueprints/root.py +++ b/src/blueprints/root.py @@ -5,9 +5,8 @@ from flask import Blueprint from sqlalchemy import text from src.classes.health_check import HealthCheck -from src.clients.postgres_client import PostgresClient +from src.clients.postgres_client import PostgresClient, postgres_is_enabled from src.utils.handle_error import handle_error -from src.utils.gather_request_metrics import gather_request_metrics from src.utils.logger import logger # from src.modules.otel_logger import logger @@ -18,15 +17,16 @@ @root_bp.route("/") @handle_error -@gather_request_metrics def home(): return "Hello, Flask!" @root_bp.route("/health-check") @handle_error -@gather_request_metrics def health_check(): + # If we're not using postgres, just return Ok + if not postgres_is_enabled(): + return HealthCheck(200, "Ok").to_dict() # Make sure we're connected to the database and can run queries pg_client = PostgresClient() query = text("SELECT count(datid) FROM pg_stat_activity;") @@ -42,7 +42,6 @@ def health_check(): @root_bp.route("/api-docs") @handle_error -@gather_request_metrics def api_docs(): global cached_api_spec if not cached_api_spec: @@ -53,7 +52,6 @@ def api_docs(): @root_bp.route("/docs") @handle_error -@gather_request_metrics def docs(): host = os.environ.get("SELF_ENDPOINT", "http://localhost:8000") swagger_ui = Template(""" diff --git a/src/classes/validation_output.py b/src/classes/validation_output.py index 827a3f1..69eda78 100644 --- a/src/classes/validation_output.py +++ b/src/classes/validation_output.py @@ -3,6 +3,8 @@ from guardrails.classes.history import Call from guardrails.utils.reask_utils import ReAsk +from src.utils.try_json_loads import try_json_loads + class ValidationOutput: def __init__( @@ -21,8 +23,14 @@ def __init__( "instructions": i.inputs.instructions.source if i.inputs.instructions is not None else None, - "output": i.outputs.raw_output - or i.outputs.llm_response_info.output, + "output": ( + i.outputs.raw_output + or ( + i.outputs.llm_response_info.output + if i.outputs.llm_response_info is not None + else None + ) + ), "parsedOutput": i.parsed_output, "prompt": { "source": i.inputs.prompt.source @@ -44,8 +52,12 @@ def __init__( # "metadata": fv.validation_result.metadata }, "valueAfterValidation": fv.value_after_validation, - "startTime": fv.start_time, - "endTime": fv.end_time, + "startTime": ( + fv.start_time.isoformat() if fv.start_time else None + ), + "endTime": ( + fv.end_time.isoformat() if fv.end_time else None + ), "instanceId": fv.instance_id, "propertyPath": fv.property_path, } @@ -58,6 +70,18 @@ def __init__( for c in calls ] self.raw_llm_response = raw_llm_response + self.validated_stream = [ + { + "chunk": raw_llm_response, + "validation_errors": [ + try_json_loads(fv.validation_result.error_message) + for fv in c.iterations.last.failed_validations + ] + if c.iterations.length > 0 + else [], + } + for c in calls + ] def to_response(self): return { @@ -65,4 +89,5 @@ def to_response(self): "validatedOutput": self.validated_output, "sessionHistory": self.session_history, "rawLlmResponse": self.raw_llm_response, + "validatedStream": self.validated_stream, } diff --git a/src/clients/guard_client.py b/src/clients/guard_client.py index 13815cd..806f2ee 100644 --- a/src/clients/guard_client.py +++ b/src/clients/guard_client.py @@ -1,94 +1,35 @@ -from typing import List +from typing import List, Union + +from guardrails import Guard from src.classes.guard_struct import GuardStruct -from src.classes.http_error import HttpError -from src.models.guard_item import GuardItem -from src.clients.postgres_client import PostgresClient -from src.models.guard_item_audit import GuardItemAudit class GuardClient: def __init__(self): self.initialized = True - self.pgClient = PostgresClient() - - def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: - latest_guard_item = ( - self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() - ) - audit_item = None - if as_of_date is not None: - audit_item = ( - self.pgClient.db.session.query(GuardItemAudit) - .filter_by(name=guard_name) - .filter(GuardItemAudit.replaced_on > as_of_date) - .order_by(GuardItemAudit.replaced_on.asc()) - .first() - ) - guard_item = audit_item if audit_item is not None else latest_guard_item - if guard_item is None: - raise HttpError( - status=404, - message="NotFound", - cause=f"A Guard with the name {guard_name} does not exist!", - ) - return GuardStruct.from_guard_item(guard_item) - - def get_guard_item(self, guard_name: str) -> GuardItem: - return ( - self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() - ) - def get_guards(self) -> List[GuardStruct]: - guard_items = self.pgClient.db.session.query(GuardItem).all() + def get_guard( + self, guard_name: str, as_of_date: str = None + ) -> Union[GuardStruct, Guard]: + raise NotImplementedError - return [GuardStruct.from_guard_item(gi) for gi in guard_items] + def get_guards(self) -> List[Union[GuardStruct, Guard]]: + raise NotImplementedError - def create_guard(self, guard: GuardStruct) -> GuardStruct: - guard_item = GuardItem( - name=guard.name, - railspec=guard.railspec.to_dict(), - num_reasks=guard.num_reasks, - description=guard.description, - ) - self.pgClient.db.session.add(guard_item) - self.pgClient.db.session.commit() - return GuardStruct.from_guard_item(guard_item) + def create_guard( + self, guard: Union[GuardStruct, Guard] + ) -> Union[GuardStruct, Guard]: + raise NotImplementedError - def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: - guard_item = self.get_guard_item(guard_name) - if guard_item is None: - raise HttpError( - status=404, - message="NotFound", - cause=f"A Guard with the name {guard_name} does not exist!", - ) - guard_item.railspec = guard.railspec.to_dict() - guard_item.num_reasks = guard.num_reasks - self.pgClient.db.session.commit() - return GuardStruct.from_guard_item(guard_item) + def update_guard( + self, guard_name: str, guard: Union[GuardStruct, Guard] + ) -> Union[GuardStruct, Guard]: + raise NotImplementedError - def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: - guard_item = self.get_guard_item(guard_name) - if guard_item is not None: - guard_item.railspec = guard.railspec.to_dict() - guard_item.num_reasks = guard.num_reasks - guard_item.description = guard.description - self.pgClient.db.session.commit() - return GuardStruct.from_guard_item(guard_item) - else: - return self.create_guard(guard) + def upsert_guard( + self, guard_name: str, guard: Union[GuardStruct, Guard] + ) -> Union[GuardStruct, Guard]: + raise NotImplementedError - def delete_guard(self, guard_name: str) -> GuardStruct: - guard_item = self.get_guard_item(guard_name) - if guard_item is None: - raise HttpError( - status=404, - message="NotFound", - cause="A Guard with the name {guard_name} does not exist!".format( - guard_name=guard_name - ), - ) - self.pgClient.db.session.delete(guard_item) - self.pgClient.db.session.commit() - guard = GuardStruct.from_guard_item(guard_item) - return guard + def delete_guard(self, guard_name: str) -> Union[GuardStruct, Guard]: + raise NotImplementedError diff --git a/src/clients/memory_guard_client.py b/src/clients/memory_guard_client.py new file mode 100644 index 0000000..a536b4d --- /dev/null +++ b/src/clients/memory_guard_client.py @@ -0,0 +1,54 @@ +from typing import List + +from guardrails import Guard +from src.classes.http_error import HttpError +from src.clients.guard_client import GuardClient + + +class MemoryGuardClient(GuardClient): + # key value pair of guard_name to guard + guards = {} + + def __init__(self): + self.initialized = True + + def get_guard(self, guard_name: str, as_of_date: str = None) -> Guard: + guard = self.guards.get(guard_name, None) + return guard + + def get_guards(self) -> List[Guard]: + return list(self.guards.values()) + + def create_guard(self, guard: Guard) -> Guard: + self.guards[guard.name] = guard + return guard + + def update_guard(self, guard_name: str, new_guard: Guard) -> Guard: + old_guard = self.get_guard(guard_name) + if old_guard is None: + raise HttpError( + status=404, + message="NotFound", + cause="A Guard with the name {guard_name} does not exist!".format( + guard_name=guard_name + ), + ) + self.guards[guard_name] = new_guard + return new_guard + + def upsert_guard(self, guard_name: str, new_guard: Guard) -> Guard: + self.create_guard(new_guard) + return new_guard + + def delete_guard(self, guard_name: str) -> Guard: + deleted_guard = self.get_guard(guard_name) + if deleted_guard is None: + raise HttpError( + status=404, + message="NotFound", + cause="A Guard with the name {guard_name} does not exist!".format( + guard_name=guard_name + ), + ) + del self.guards[guard_name] + return deleted_guard diff --git a/src/clients/pg_guard_client.py b/src/clients/pg_guard_client.py new file mode 100644 index 0000000..9684be1 --- /dev/null +++ b/src/clients/pg_guard_client.py @@ -0,0 +1,98 @@ +from typing import List +from src.classes.guard_struct import GuardStruct +from src.classes.http_error import HttpError +from src.clients.guard_client import GuardClient +from src.models.guard_item import GuardItem +from src.clients.postgres_client import PostgresClient +from src.models.guard_item_audit import GuardItemAudit + + +class PGGuardClient(GuardClient): + def __init__(self): + self.initialized = True + self.pgClient = PostgresClient() + + def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: + latest_guard_item = ( + self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() + ) + audit_item = None + if as_of_date is not None: + audit_item = ( + self.pgClient.db.session.query(GuardItemAudit) + .filter_by(name=guard_name) + .filter(GuardItemAudit.replaced_on > as_of_date) + .order_by(GuardItemAudit.replaced_on.asc()) + .first() + ) + guard_item = audit_item if audit_item is not None else latest_guard_item + if guard_item is None: + raise HttpError( + status=404, + message="NotFound", + cause="A Guard with the name {guard_name} does not exist!".format( + guard_name=guard_name + ), + ) + return GuardStruct.from_guard_item(guard_item) + + def get_guard_item(self, guard_name: str) -> GuardItem: + return ( + self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() + ) + + def get_guards(self) -> List[GuardStruct]: + guard_items = self.pgClient.db.session.query(GuardItem).all() + + return [GuardStruct.from_guard_item(gi) for gi in guard_items] + + def create_guard(self, guard: GuardStruct) -> GuardStruct: + guard_item = GuardItem( + name=guard.name, + railspec=guard.railspec.to_dict(), + num_reasks=guard.num_reasks, + description=guard.description, + ) + self.pgClient.db.session.add(guard_item) + self.pgClient.db.session.commit() + return GuardStruct.from_guard_item(guard_item) + + def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: + guard_item = self.get_guard_item(guard_name) + if guard_item is None: + raise HttpError( + status=404, + message="NotFound", + cause="A Guard with the name {guard_name} does not exist!".format( + guard_name=guard_name + ), + ) + guard_item.railspec = guard.railspec.to_dict() + guard_item.num_reasks = guard.num_reasks + self.pgClient.db.session.commit() + return GuardStruct.from_guard_item(guard_item) + + def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: + guard_item = self.get_guard_item(guard_name) + if guard_item is not None: + guard_item.railspec = guard.railspec.to_dict() + guard_item.num_reasks = guard.num_reasks + self.pgClient.db.session.commit() + return GuardStruct.from_guard_item(guard_item) + else: + return self.create_guard(guard) + + def delete_guard(self, guard_name: str) -> GuardStruct: + guard_item = self.get_guard_item(guard_name) + if guard_item is None: + raise HttpError( + status=404, + message="NotFound", + cause="A Guard with the name {guard_name} does not exist!".format( + guard_name=guard_name + ), + ) + self.pgClient.db.session.delete(guard_item) + self.pgClient.db.session.commit() + guard = GuardStruct.from_guard_item(guard_item) + return guard diff --git a/src/clients/postgres_client.py b/src/clients/postgres_client.py index 6a868c1..ed814d7 100644 --- a/src/clients/postgres_client.py +++ b/src/clients/postgres_client.py @@ -7,6 +7,10 @@ from src.models.base import db, INIT_EXTENSIONS +def postgres_is_enabled() -> bool: + return os.environ.get("PGHOST", None) is not None + + class PostgresClient: _instance = None diff --git a/src/otel/__init__.py b/src/otel/__init__.py new file mode 100644 index 0000000..e44072b --- /dev/null +++ b/src/otel/__init__.py @@ -0,0 +1,28 @@ +import os +from src.otel.logs import logs_are_disabled +from src.otel.metrics import ( + initialize_metrics_collector, + metrics_are_disabled, + get_meter, # noqa +) +from src.otel.traces import ( + traces_are_disabled, + initialize_tracer, + get_tracer, # noqa +) + + +def otel_is_disabled() -> bool: + sdk_is_disabled = os.environ.get("OTEL_SDK_DISABLED") == "true" + + all_signals_disabled = ( + traces_are_disabled() and metrics_are_disabled() and logs_are_disabled() + ) + return sdk_is_disabled or all_signals_disabled + + +def initialize(): + initialize_tracer() + initialize_metrics_collector() + # Logs are supported yet in the Python SDK + # initialize_logs_collector() diff --git a/src/otel/constants.py b/src/otel/constants.py new file mode 100644 index 0000000..c63514a --- /dev/null +++ b/src/otel/constants.py @@ -0,0 +1 @@ +none = "none" diff --git a/src/otel/logs.py b/src/otel/logs.py new file mode 100644 index 0000000..4418b2f --- /dev/null +++ b/src/otel/logs.py @@ -0,0 +1,7 @@ +import os +from src.otel.constants import none + + +def logs_are_disabled() -> bool: + otel_logs_exporter = os.environ.get("OTEL_LOGS_EXPORTER", none) + return otel_logs_exporter == none diff --git a/src/otel/metrics.py b/src/otel/metrics.py new file mode 100644 index 0000000..a009566 --- /dev/null +++ b/src/otel/metrics.py @@ -0,0 +1,59 @@ +import os +from typing import Optional +from opentelemetry import metrics +from opentelemetry.metrics import Meter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import ( + ConsoleMetricExporter, + PeriodicExportingMetricReader, + MetricExporter, +) +from opentelemetry.exporter.otlp.proto.http.metric_exporter import ( + OTLPMetricExporter as HttpMetricExporter, +) +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( + OTLPMetricExporter as GrpcMetricExporter, +) +from src.otel.constants import none + + +def metrics_are_disabled() -> bool: + otel_metrics_exporter = os.environ.get("OTEL_METRICS_EXPORTER", none) + return otel_metrics_exporter == none + + +def get_meter(name: Optional[str] = None) -> Meter: + meter_name = name or os.environ.get("OTEL_SERVICE_NAME", "guardrails-api") + meter = metrics.get_meter(meter_name) + + return meter + + +def get_metrics_exporter(exporter_type: str) -> MetricExporter: + if exporter_type == "otlp": + otlp_protocol = os.environ.get("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf") + metrics_exporter = HttpMetricExporter() + if otlp_protocol == "grpc": + metrics_exporter = GrpcMetricExporter() + return metrics_exporter + elif exporter_type == "console": + return ConsoleMetricExporter() + + +def initialize_metrics_collector(): + if not metrics_are_disabled(): + metrics_exporter_settings = os.environ.get( + "OTEL_METRICS_EXPORTER", "none" + ).split(",") + metric_exporters = [ + get_metrics_exporter(e) for e in metrics_exporter_settings if e != "none" + ] + + metric_readers = [] + for exporter in metric_exporters: + metric_readers.append(PeriodicExportingMetricReader(exporter)) + + provider = MeterProvider(metric_readers=metric_readers) + metrics.set_meter_provider(provider) + + get_meter() diff --git a/src/otel/traces.py b/src/otel/traces.py new file mode 100644 index 0000000..6257d23 --- /dev/null +++ b/src/otel/traces.py @@ -0,0 +1,72 @@ +import os +from typing import Optional +from opentelemetry import trace +from opentelemetry.trace import Tracer +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ( + BatchSpanProcessor, + SimpleSpanProcessor, + ConsoleSpanExporter, + SpanExporter, + SpanProcessor, +) +from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter as HttpSpanExporter, +) +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter as GrpcSpanExporter, +) +from src.otel.constants import none + + +def traces_are_disabled() -> bool: + otel_traces_exporter = os.environ.get("OTEL_TRACES_EXPORTER", none) + return otel_traces_exporter == none + + +def get_tracer(name: Optional[str] = None) -> Tracer: + tracer_name = name or os.environ.get("OTEL_SERVICE_NAME", "guardrails-api") + tracer = trace.get_tracer(tracer_name) + + return tracer + + +def get_span_exporter(exporter_type: str) -> SpanExporter: + if exporter_type == "otlp": + otlp_protocol = os.environ.get("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf") + trace_exporter = HttpSpanExporter() + if otlp_protocol == "grpc": + trace_exporter = GrpcSpanExporter() + return trace_exporter + elif exporter_type == "console": + return ConsoleSpanExporter() + + +def set_span_processors( + tracer_provider: TracerProvider, + exporter: SpanExporter, + use_batch: bool, +) -> SpanProcessor: + span_processor = BatchSpanProcessor(exporter) + if not use_batch: + span_processor = SimpleSpanProcessor(exporter) + tracer_provider.add_span_processor(span_processor) + + +def initialize_tracer(): + if not traces_are_disabled(): + tracer_provider = trace.get_tracer_provider() + + trace_exporter_settings = os.environ.get("OTEL_TRACES_EXPORTER", "none").split( + "," + ) + trace_exporters = [ + get_span_exporter(e) for e in trace_exporter_settings if e != "none" + ] + + use_batch = os.environ.get("OTEL_PROCESS_IN_BATCH", "true") == "true" + for exporter in trace_exporters: + set_span_processors(tracer_provider, exporter, use_batch) + + # Initialize singleton + get_tracer() diff --git a/src/utils/get_llm_callable.py b/src/utils/get_llm_callable.py index bc68728..3c12a05 100644 --- a/src/utils/get_llm_callable.py +++ b/src/utils/get_llm_callable.py @@ -1,3 +1,4 @@ +import litellm from typing import Any, Awaitable, Callable, Union from guardrails.utils.openai_utils import ( get_static_openai_create_func, @@ -14,27 +15,27 @@ def get_llm_callable( llm_api: str, ) -> Union[Callable, Callable[[Any], Awaitable[Any]]]: try: + model = ValidatePayloadLlmApi(llm_api) # TODO: Add error handling and throw 400 if ( - ValidatePayloadLlmApi(llm_api) - is ValidatePayloadLlmApi.OPENAI_COMPLETION_CREATE + model is ValidatePayloadLlmApi.OPENAI_COMPLETION_CREATE + or model is ValidatePayloadLlmApi.OPENAI_COMPLETIONS_CREATE ): return get_static_openai_create_func() elif ( - ValidatePayloadLlmApi(llm_api) - is ValidatePayloadLlmApi.OPENAI_CHATCOMPLETION_CREATE + model is ValidatePayloadLlmApi.OPENAI_CHATCOMPLETION_CREATE + or model is ValidatePayloadLlmApi.OPENAI_CHAT_COMPLETIONS_CREATE ): return get_static_openai_chat_create_func() - elif ( - ValidatePayloadLlmApi(llm_api) - is ValidatePayloadLlmApi.OPENAI_COMPLETION_ACREATE - ): + elif model is ValidatePayloadLlmApi.OPENAI_COMPLETION_ACREATE: return get_static_openai_acreate_func() - elif ( - ValidatePayloadLlmApi(llm_api) - is ValidatePayloadLlmApi.OPENAI_CHATCOMPLETION_ACREATE - ): + elif model is ValidatePayloadLlmApi.OPENAI_CHATCOMPLETION_ACREATE: return get_static_openai_chat_acreate_func() + elif model is ValidatePayloadLlmApi.LITELLM_COMPLETION: + return litellm.completion + elif model is ValidatePayloadLlmApi.LITELLM_ACOMPLETION: + return litellm.acompletion + else: pass except Exception: diff --git a/src/utils/handle_error.py b/src/utils/handle_error.py index b21796c..a47bfb1 100644 --- a/src/utils/handle_error.py +++ b/src/utils/handle_error.py @@ -15,6 +15,8 @@ def decorator(*args, **kwargs): traceback.print_exception(http_error) return http_error.to_dict(), http_error.status except HTTPException as http_exception: + logger.error(http_exception) + traceback.print_exception(http_exception) http_error = HttpError(http_exception.code, http_exception.description) return http_error.to_dict(), http_error.status except Exception as e: diff --git a/src/utils/try_json_loads.py b/src/utils/try_json_loads.py new file mode 100644 index 0000000..ad50ffb --- /dev/null +++ b/src/utils/try_json_loads.py @@ -0,0 +1,9 @@ +import json + + +def try_json_loads(val): + try: + string_val = json.loads(val, default=str) + return string_val + except Exception: + return val diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 0000000..8784b77 --- /dev/null +++ b/test.ipynb @@ -0,0 +1,826 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "some-token\n", + "http://localhost:8000\n" + ] + } + ], + "source": [ + "import os\n", + "\n", + "os.environ[\"GUARDRAILS_API_KEY\"] = \"some-token\"\n", + "os.environ[\"GUARDRAILS_BASE_URL\"] = \"http://localhost:8000\"\n", + "\n", + "\n", + "print(os.environ.get(\"GUARDRAILS_API_KEY\"))\n", + "print(os.environ.get(\"GUARDRAILS_BASE_URL\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/calebcourier/Projects/gr-mono/guardrails-cdk/guardrails-api/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "prov-test\n" + ] + } + ], + "source": [ + "from guardrails import Guard\n", + "from guardrails.hub import ProvenanceLLM\n", + "\n", + "SOURCES = [\n", + " \"The sun is a star.\",\n", + " \"The sun rises in the east and sets in the west.\",\n", + " \"Sun is the largest object in the solar system, and all planets revolve around it.\",\n", + "]\n", + "\n", + "guard = Guard(name=\"prov-test\").use(\n", + " ProvenanceLLM(),\n", + " stream=True\n", + ")\n", + "\n", + "print(guard.name)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
ValidationOutcome(raw_llm_output='The', validated_output='The', reask=None, validation_passed=True, error=None)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The'\u001b[0m, \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The'\u001b[0m, \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m, \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m, \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun',\n",
+       "    validated_output='The sun',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is',\n",
+       "    validated_output='The sun is',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a',\n",
+       "    validated_output='The sun is a',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star',\n",
+       "    validated_output='The sun is a star',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at',\n",
+       "    validated_output='The sun is a star at',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the',\n",
+       "    validated_output='The sun is a star at the',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center',\n",
+       "    validated_output='The sun is a star at the center',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of',\n",
+       "    validated_output='The sun is a star at the center of',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our',\n",
+       "    validated_output='The sun is a star at the center of our',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar',\n",
+       "    validated_output='The sun is a star at the center of our solar',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system',\n",
+       "    validated_output='The sun is a star at the center of our solar system',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that',\n",
+       "    validated_output='The sun is a star at the center of our solar system that',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that provides',\n",
+       "    validated_output='The sun is a star at the center of our solar system that provides',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that provides light',\n",
+       "    validated_output='The sun is a star at the center of our solar system that provides light',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that provides light and',\n",
+       "    validated_output='The sun is a star at the center of our solar system that provides light and',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that provides light and heat',\n",
+       "    validated_output='The sun is a star at the center of our solar system that provides light and heat',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that provides light and heat to',\n",
+       "    validated_output='The sun is a star at the center of our solar system that provides light and heat to',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that provides light and heat to Earth',\n",
+       "    validated_output='The sun is a star at the center of our solar system that provides light and heat to Earth',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to Earth'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to Earth'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that provides light and heat to Earth.',\n",
+       "    validated_output='The sun is a star at the center of our solar system that provides light and heat to Earth.',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to Earth.'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to Earth.'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that provides light and heat to Earth.',\n",
+       "    validated_output='The sun is a star at the center of our solar system that provides light and heat to Earth.',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to Earth.'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to Earth.'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that provides light and heat to Earth.',\n",
+       "    validated_output='The sun is a star at the center of our solar system that provides light and heat to Earth.',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to Earth.'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to Earth.'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
last message: \n",
+       "ValidationOutcome(\n",
+       "    raw_llm_output='The sun is a star at the center of our solar system that provides light and heat to Earth.',\n",
+       "    validated_output='The sun is a star at the center of our solar system that provides light and heat to Earth.',\n",
+       "    reask=None,\n",
+       "    validation_passed=True,\n",
+       "    error=None\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "last message: \n", + "\u001b[1;35mValidationOutcome\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mraw_llm_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to Earth.'\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[32m'The sun is a star at the center of our solar system that provides light and heat to Earth.'\u001b[0m,\n", + " \u001b[33mreask\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mvalidation_passed\u001b[0m=\u001b[3;92mTrue\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import litellm\n", + "from rich import print\n", + "\n", + "response = guard(\n", + " llm_api=litellm.completion,\n", + " model=\"gpt-3.5-turbo\",\n", + " instructions=\"You are a helpful assistant.\",\n", + " prompt=\"Write a short and accurate statement about the sun.\",\n", + " metadata={\"sources\": SOURCES},\n", + " stream=True\n", + ")\n", + "\n", + "fragment_count = 0\n", + "last_message = None\n", + "for message in response:\n", + " fragment_count += 1\n", + " print(message)\n", + " last_message = message\n", + " \n", + "print(\"last message: \", last_message)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
Call(\n",
+       "    iterations=[\n",
+       "        Iteration(\n",
+       "            inputs=Inputs(\n",
+       "                llm_api=None,\n",
+       "                llm_output=None,\n",
+       "                instructions=Instructions(You are a helpful assistant.),\n",
+       "                prompt=Prompt(Write a short and accurate statement about the sun...),\n",
+       "                msg_history=None,\n",
+       "                prompt_params={},\n",
+       "                num_reasks=1,\n",
+       "                metadata={\n",
+       "                    'sources': [\n",
+       "                        'The sun is a star.',\n",
+       "                        'The sun rises in the east and sets in the west.',\n",
+       "                        'Sun is the largest object in the solar system, and all planets revolve around it.'\n",
+       "                    ]\n",
+       "                },\n",
+       "                full_schema_reask=False\n",
+       "            ),\n",
+       "            outputs=Outputs(\n",
+       "                llm_response_info=LLMResponse(\n",
+       "                    prompt_token_count=None,\n",
+       "                    response_token_count=None,\n",
+       "                    output='The sun is a star located at the center of our solar system, providing light and heat \n",
+       "to Earth.',\n",
+       "                    stream_output=None\n",
+       "                ),\n",
+       "                raw_output='The sun is a star located at the center of our solar system, providing light and heat \n",
+       "to Earth.',\n",
+       "                parsed_output='The sun is a star located at the center of our solar system, providing light and \n",
+       "heat to Earth.',\n",
+       "                validation_response=None,\n",
+       "                guarded_output=None,\n",
+       "                reasks=[],\n",
+       "                validator_logs=[],\n",
+       "                error=None,\n",
+       "                exception=None\n",
+       "            )\n",
+       "        )\n",
+       "    ],\n",
+       "    inputs=CallInputs(\n",
+       "        llm_api=<function completion at 0x16800a8e0>,\n",
+       "        llm_output=None,\n",
+       "        instructions='You are a helpful assistant.',\n",
+       "        prompt='Write a short and accurate statement about the sun.',\n",
+       "        msg_history=None,\n",
+       "        prompt_params={},\n",
+       "        num_reasks=1,\n",
+       "        metadata={\n",
+       "            'sources': [\n",
+       "                'The sun is a star.',\n",
+       "                'The sun rises in the east and sets in the west.',\n",
+       "                'Sun is the largest object in the solar system, and all planets revolve around it.'\n",
+       "            ]\n",
+       "        },\n",
+       "        full_schema_reask=False,\n",
+       "        args=[],\n",
+       "        kwargs={'model': 'gpt-3.5-turbo', 'stream': True}\n",
+       "    )\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mCall\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33miterations\u001b[0m=\u001b[1m[\u001b[0m\n", + " \u001b[1;35mIteration\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33minputs\u001b[0m=\u001b[1;35mInputs\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mllm_api\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mllm_output\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33minstructions\u001b[0m=\u001b[1;35mInstructions\u001b[0m\u001b[1m(\u001b[0mYou are a helpful assistant.\u001b[1m)\u001b[0m,\n", + " \u001b[33mprompt\u001b[0m=\u001b[1;35mPrompt\u001b[0m\u001b[1m(\u001b[0mWrite a short and accurate statement about the sun\u001b[33m...\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[33mmsg_history\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mprompt_params\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + " \u001b[33mnum_reasks\u001b[0m=\u001b[1;36m1\u001b[0m,\n", + " \u001b[33mmetadata\u001b[0m=\u001b[1m{\u001b[0m\n", + " \u001b[32m'sources'\u001b[0m: \u001b[1m[\u001b[0m\n", + " \u001b[32m'The sun is a star.'\u001b[0m,\n", + " \u001b[32m'The sun rises in the east and sets in the west.'\u001b[0m,\n", + " \u001b[32m'Sun is the largest object in the solar system, and all planets revolve around it.'\u001b[0m\n", + " \u001b[1m]\u001b[0m\n", + " \u001b[1m}\u001b[0m,\n", + " \u001b[33mfull_schema_reask\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + " \u001b[1m)\u001b[0m,\n", + " \u001b[33moutputs\u001b[0m=\u001b[1;35mOutputs\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mllm_response_info\u001b[0m=\u001b[1;35mLLMResponse\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mprompt_token_count\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mresponse_token_count\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33moutput\u001b[0m=\u001b[32m'The sun is a star located at the center of our solar system, providing light and heat \u001b[0m\n", + "\u001b[32mto Earth.'\u001b[0m,\n", + " \u001b[33mstream_output\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + " \u001b[1m)\u001b[0m,\n", + " \u001b[33mraw_output\u001b[0m=\u001b[32m'The sun is a star located at the center of our solar system, providing light and heat \u001b[0m\n", + "\u001b[32mto Earth.'\u001b[0m,\n", + " \u001b[33mparsed_output\u001b[0m=\u001b[32m'The sun is a star located at the center of our solar system, providing light and \u001b[0m\n", + "\u001b[32mheat to Earth.'\u001b[0m,\n", + " \u001b[33mvalidation_response\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mguarded_output\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mreasks\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mvalidator_logs\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33merror\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mexception\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + " \u001b[1m)\u001b[0m\n", + " \u001b[1m)\u001b[0m\n", + " \u001b[1m]\u001b[0m,\n", + " \u001b[33minputs\u001b[0m=\u001b[1;35mCallInputs\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mllm_api\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mfunction\u001b[0m\u001b[39m completion at \u001b[0m\u001b[1;36m0x16800a8e0\u001b[0m\u001b[1m>\u001b[0m,\n", + " \u001b[33mllm_output\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33minstructions\u001b[0m=\u001b[32m'You are a helpful assistant.'\u001b[0m,\n", + " \u001b[33mprompt\u001b[0m=\u001b[32m'Write a short and accurate statement about the sun.'\u001b[0m,\n", + " \u001b[33mmsg_history\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + " \u001b[33mprompt_params\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + " \u001b[33mnum_reasks\u001b[0m=\u001b[1;36m1\u001b[0m,\n", + " \u001b[33mmetadata\u001b[0m=\u001b[1m{\u001b[0m\n", + " \u001b[32m'sources'\u001b[0m: \u001b[1m[\u001b[0m\n", + " \u001b[32m'The sun is a star.'\u001b[0m,\n", + " \u001b[32m'The sun rises in the east and sets in the west.'\u001b[0m,\n", + " \u001b[32m'Sun is the largest object in the solar system, and all planets revolve around it.'\u001b[0m\n", + " \u001b[1m]\u001b[0m\n", + " \u001b[1m}\u001b[0m,\n", + " \u001b[33mfull_schema_reask\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + " \u001b[33margs\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mkwargs\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'model'\u001b[0m: \u001b[32m'gpt-3.5-turbo'\u001b[0m, \u001b[32m'stream'\u001b[0m: \u001b[3;92mTrue\u001b[0m\u001b[1m}\u001b[0m\n", + " \u001b[1m)\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print(guard.history.last)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/blueprints/test_guards.py b/tests/blueprints/test_guards.py index 154fa2b..666fd72 100644 --- a/tests/blueprints/test_guards.py +++ b/tests/blueprints/test_guards.py @@ -1,12 +1,28 @@ -from unittest.mock import PropertyMock, call +import os +from unittest.mock import PropertyMock from typing import Dict, Tuple + +import pytest + from tests.mocks.mock_blueprint import MockBlueprint from tests.mocks.mock_guard_client import MockGuardStruct from tests.mocks.mock_request import MockRequest from guardrails.classes import ValidationOutcome from guardrails.classes.generic import Stack -from guardrails.classes.history import Call, CallInputs -from tests.mocks.mock_trace import MockTracer +from guardrails.classes.history import Call +# from tests.mocks.mock_trace import MockTracer + + +@pytest.fixture(autouse=True) +def around_each(): + # Code that will run before the test + openai_api_key_bak = os.environ.get("OPENAI_API_KEY") + if openai_api_key_bak: + del os.environ["OPENAI_API_KEY"] + yield + # Code that will run after the test + if openai_api_key_bak: + os.environ["OPENAI_API_KEY"] = openai_api_key_bak def test_route_setup(mocker): @@ -25,9 +41,16 @@ def test_guards__get(mocker): mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) mock_get_guards = mocker.patch( - "src.blueprints.guards.GuardClient.get_guards", return_value=[mock_guard] + "src.blueprints.guards.guard_client.get_guards", return_value=[mock_guard] ) - mocker.patch("src.blueprints.guards.get_tracer") + mocker.patch("src.blueprints.guards.collect_telemetry") + + # >>> Conflict + # mock_get_guards = mocker.patch( + # "src.blueprints.guards.guard_client.get_guards", return_value=[mock_guard] + # ) + # mocker.patch("src.blueprints.guards.get_tracer") + from src.blueprints.guards import guards response = guards() @@ -37,7 +60,8 @@ def test_guards__get(mocker): assert response == [{"name": "mock-guard"}] -def test_guards__post(mocker): +def test_guards__post_pg(mocker): + os.environ["PGHOST"] = "localhost" mock_guard = MockGuardStruct() mock_request = MockRequest("POST", mock_guard.to_response()) @@ -47,25 +71,43 @@ def test_guards__post(mocker): "src.blueprints.guards.GuardStruct.from_request", return_value=mock_guard ) mock_create_guard = mocker.patch( - "src.blueprints.guards.GuardClient.create_guard", return_value=mock_guard + "src.blueprints.guards.guard_client.create_guard", return_value=mock_guard ) - mocker.patch("src.blueprints.guards.get_tracer") + from src.blueprints.guards import guards response = guards() - assert mock_from_request.called_once_with(mock_guard) - assert mock_create_guard.called_once_with(mock_guard) + mock_from_request.assert_called_once_with(mock_guard.to_response()) + mock_create_guard.assert_called_once_with(mock_guard) assert response == {"name": "mock-guard"} + del os.environ["PGHOST"] + + +def test_guards__post_mem(mocker): + mock_guard = MockGuardStruct() + mock_request = MockRequest("POST", mock_guard.to_response()) + + mocker.patch("flask.Blueprint", new=MockBlueprint) + mocker.patch("src.blueprints.guards.request", mock_request) + + from src.blueprints.guards import guards + + response = guards() + + error_body, status = response + + assert status == 501 + def test_guards__raises(mocker): mock_request = MockRequest("PUT") mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mocker.patch("src.blueprints.guards.get_tracer") + # mocker.patch("src.blueprints.guards.get_tracer") mocker.patch("src.utils.handle_error.logger.error") mocker.patch("src.utils.handle_error.traceback.print_exception") from src.blueprints.guards import guards @@ -84,63 +126,94 @@ def test_guards__raises(mocker): assert status == 405 -def test_guard__get(mocker): +def test_guard__get_mem(mocker): mock_guard = MockGuardStruct() timestamp = "2024-03-04T14:11:42-06:00" mock_request = MockRequest("GET", args={"asOf": timestamp}) mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) + mock_get_guard = mocker.patch( - "src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard + "src.blueprints.guards.guard_client.get_guard", return_value=mock_guard ) - mocker.patch("src.blueprints.guards.get_tracer") + # mocker.patch("src.blueprints.guards.get_tracer") + + # >>> Conflict + # mock_get_guard = mocker.patch( + # "src.blueprints.guards.guard_client.get_guard", return_value=mock_guard + # ) + # mocker.patch("src.blueprints.guards.get_tracer") + from src.blueprints.guards import guard response = guard("My%20Guard's%20Name") - assert mock_get_guard.called_once_with("My Guard's Name", timestamp) + mock_get_guard.assert_called_once_with("My Guard's Name", timestamp) assert response == {"name": "mock-guard"} -def test_guard__put(mocker): +def test_guard__put_pg(mocker): + os.environ["PGHOST"] = "localhost" mock_guard = MockGuardStruct() mock_request = MockRequest("PUT", json={"name": "mock-guard"}) mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) + mock_from_request = mocker.patch( "src.blueprints.guards.GuardStruct.from_request", return_value=mock_guard ) mock_upsert_guard = mocker.patch( - "src.blueprints.guards.GuardClient.upsert_guard", return_value=mock_guard + "src.blueprints.guards.guard_client.upsert_guard", return_value=mock_guard ) - mocker.patch("src.blueprints.guards.get_tracer") + # mocker.patch("src.blueprints.guards.get_tracer") + + # >>> Conflict + # mock_from_request = mocker.patch( + # "src.blueprints.guards.GuardStruct.from_request", return_value=mock_guard + # ) + # mock_upsert_guard = mocker.patch( + # "src.blueprints.guards.guard_client.upsert_guard", return_value=mock_guard + # ) + # mocker.patch("src.blueprints.guards.get_tracer") + from src.blueprints.guards import guard response = guard("My%20Guard's%20Name") - assert mock_from_request.called_once_with(mock_guard) - assert mock_upsert_guard.called_once_with("My Guard's Name", mock_guard) + mock_from_request.assert_called_once_with(mock_guard.to_response()) + mock_upsert_guard.assert_called_once_with("My Guard's Name", mock_guard) assert response == {"name": "mock-guard"} + del os.environ["PGHOST"] -def test_guard__delete(mocker): +def test_guard__delete_pg(mocker): + os.environ["PGHOST"] = "localhost" mock_guard = MockGuardStruct() mock_request = MockRequest("DELETE") mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) + mock_delete_guard = mocker.patch( - "src.blueprints.guards.GuardClient.delete_guard", return_value=mock_guard + "src.blueprints.guards.guard_client.delete_guard", return_value=mock_guard ) - mocker.patch("src.blueprints.guards.get_tracer") + # mocker.patch("src.blueprints.guards.get_tracer") + + # >>> Conflict + # mock_delete_guard = mocker.patch( + # "src.blueprints.guards.guard_client.delete_guard", return_value=mock_guard + # ) + # mocker.patch("src.blueprints.guards.get_tracer") + from src.blueprints.guards import guard response = guard("my-guard-name") - assert mock_delete_guard.called_once_with("my-guard-name") + mock_delete_guard.assert_called_once_with("my-guard-name") assert response == {"name": "mock-guard"} + del os.environ["PGHOST"] def test_guard__raises(mocker): @@ -148,7 +221,7 @@ def test_guard__raises(mocker): mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mocker.patch("src.blueprints.guards.get_tracer") + # mocker.patch("src.blueprints.guards.get_tracer") mocker.patch("src.utils.handle_error.logger.error") mocker.patch("src.utils.handle_error.traceback.print_exception") from src.blueprints.guards import guard @@ -172,7 +245,7 @@ def test_validate__raises_method_not_allowed(mocker): mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mocker.patch("src.blueprints.guards.get_tracer") + # mocker.patch("src.blueprints.guards.get_tracer") mocker.patch("src.utils.handle_error.logger.error") mocker.patch("src.utils.handle_error.traceback.print_exception") from src.blueprints.guards import validate @@ -192,17 +265,18 @@ def test_validate__raises_method_not_allowed(mocker): def test_validate__raises_bad_request__openai_api_key(mocker): + os.environ["PGHOST"] = "localhost" mock_guard = MockGuardStruct() - mock_tracer = MockTracer() + # mock_tracer = MockTracer() mock_request = MockRequest("POST", json={"llmApi": "bar"}) mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) mock_get_guard = mocker.patch( - "src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard + "src.blueprints.guards.guard_client.get_guard", return_value=mock_guard ) mock_prep_environment = mocker.patch("src.blueprints.guards.prep_environment") - mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) + # mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) mocker.patch("src.utils.handle_error.logger.error") mocker.patch("src.utils.handle_error.traceback.print_exception") from src.blueprints.guards import validate @@ -210,7 +284,7 @@ def test_validate__raises_bad_request__openai_api_key(mocker): response = validate("My%20Guard's%20Name") assert mock_prep_environment.call_count == 1 - assert mock_get_guard.called_once_with("My Guard's Name") + mock_get_guard.assert_called_once_with("My Guard's Name") assert isinstance(response, Tuple) error, status = response @@ -223,20 +297,22 @@ def test_validate__raises_bad_request__openai_api_key(mocker): " OPENAI_API_KEY environment variable." ) assert status == 400 + del os.environ["PGHOST"] def test_validate__raises_bad_request__num_reasks(mocker): + os.environ["PGHOST"] = "localhost" mock_guard = MockGuardStruct() - mock_tracer = MockTracer() + # mock_tracer = MockTracer() mock_request = MockRequest("POST", json={"numReasks": 3}) mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) mock_get_guard = mocker.patch( - "src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard + "src.blueprints.guards.guard_client.get_guard", return_value=mock_guard ) mock_prep_environment = mocker.patch("src.blueprints.guards.prep_environment") - mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) + # mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) mocker.patch("src.utils.handle_error.logger.error") mocker.patch("src.utils.handle_error.traceback.print_exception") from src.blueprints.guards import validate @@ -244,7 +320,7 @@ def test_validate__raises_bad_request__num_reasks(mocker): response = validate("My%20Guard's%20Name") assert mock_prep_environment.call_count == 1 - assert mock_get_guard.called_once_with("My Guard's Name") + mock_get_guard.assert_called_once_with("My Guard's Name") assert isinstance(response, Tuple) error, status = response @@ -256,9 +332,11 @@ def test_validate__raises_bad_request__num_reasks(mocker): " calling guard(...)." ) assert status == 400 + del os.environ["PGHOST"] def test_validate__parse(mocker): + os.environ["PGHOST"] = "localhost" mock_parse = mocker.patch.object(MockGuardStruct, "parse") mock_parse.return_value = ValidationOutcome( raw_llm_output="Hello world!", @@ -266,7 +344,7 @@ def test_validate__parse(mocker): validation_passed=True, ) mock_guard = MockGuardStruct() - mock_tracer = MockTracer() + # mock_tracer = MockTracer() mock_request = MockRequest( "POST", json={"llmOutput": "Hello world!", "args": [1, 2, 3], "some_kwarg": "foo"}, @@ -275,50 +353,56 @@ def test_validate__parse(mocker): mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) mock_get_guard = mocker.patch( - "src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard + "src.blueprints.guards.guard_client.get_guard", return_value=mock_guard ) mock_prep_environment = mocker.patch("src.blueprints.guards.prep_environment") mock_cleanup_environment = mocker.patch("src.blueprints.guards.cleanup_environment") - mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) - set_attribute_spy = mocker.spy(mock_tracer.span, "set_attribute") + # mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) + + # >>> Conflict + # mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) + + # set_attribute_spy = mocker.spy(mock_tracer.span, "set_attribute") mock_status = mocker.patch( "guardrails.classes.history.call.Call.status", new_callable=PropertyMock ) mock_status.return_value = "pass" - mock_guard.history = Stack(Call(inputs=CallInputs(prompt="Hello world prompt!"))) + mock_guard.history = Stack(Call()) from src.blueprints.guards import validate response = validate("My%20Guard's%20Name") assert mock_prep_environment.call_count == 1 - assert mock_get_guard.called_once_with("My Guard's Name") + mock_get_guard.assert_called_once_with("My Guard's Name") assert mock_parse.call_count == 1 - assert mock_parse.called_once_with( + mock_parse.assert_called_once_with( 1, 2, 3, llm_output="Hello world!", num_reasks=0, - prompt_params=None, + prompt_params={}, llm_api=None, some_kwarg="foo", + api_key=None, ) - assert set_attribute_spy.call_count == 7 - expected_calls = [ - call("guardName", "My Guard's Name"), - call("prompt", "Hello world prompt!"), - call("validation_status", "pass"), - call("raw_llm_ouput", "Hello world!"), - call("validated_output", "Hello world!"), - call("tokens_consumed", None), - call("num_of_reasks", 0), - ] - set_attribute_spy.assert_has_calls(expected_calls) + # Temporarily Disabled + # assert set_attribute_spy.call_count == 7 + # expected_calls = [ + # call("guardName", "My Guard's Name"), + # call("prompt", "Hello world prompt!"), + # call("validation_status", "pass"), + # call("raw_llm_ouput", "Hello world!"), + # call("validated_output", "Hello world!"), + # call("tokens_consumed", None), + # call("num_of_reasks", 0), + # ] + # set_attribute_spy.assert_has_calls(expected_calls) assert mock_cleanup_environment.call_count == 1 @@ -327,16 +411,20 @@ def test_validate__parse(mocker): "validatedOutput": "Hello world!", "sessionHistory": [{"history": []}], "rawLlmResponse": "Hello world!", + "validatedStream": [{"chunk": "Hello world!", "validation_errors": []}], } + del os.environ["PGHOST"] + def test_validate__call(mocker): + os.environ["PGHOST"] = "localhost" mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") mock___call__.return_value = ValidationOutcome( raw_llm_output="Hello world!", validated_output=None, validation_passed=False ) mock_guard = MockGuardStruct() - mock_tracer = MockTracer() + # mock_tracer = MockTracer() mock_request = MockRequest( "POST", json={ @@ -351,35 +439,37 @@ def test_validate__call(mocker): mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) mock_get_guard = mocker.patch( - "src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard + "src.blueprints.guards.guard_client.get_guard", return_value=mock_guard ) mock_prep_environment = mocker.patch("src.blueprints.guards.prep_environment") mock_cleanup_environment = mocker.patch("src.blueprints.guards.cleanup_environment") - mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) + mocker.patch( + "src.blueprints.guards.get_llm_callable", + return_value="openai.Completion.create", + ) - set_attribute_spy = mocker.spy(mock_tracer.span, "set_attribute") + # mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) + + # >>> Conflict + # mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) + + # set_attribute_spy = mocker.spy(mock_tracer.span, "set_attribute") mock_status = mocker.patch( "guardrails.classes.history.call.Call.status", new_callable=PropertyMock ) mock_status.return_value = "fail" - mock_guard.history = Stack( - Call( - inputs=CallInputs( - prompt="Hello world prompt!", instructions="Hello world instructions!" - ) - ) - ) + mock_guard.history = Stack(Call()) from src.blueprints.guards import validate response = validate("My%20Guard's%20Name") assert mock_prep_environment.call_count == 1 - assert mock_get_guard.called_once_with("My Guard's Name") + mock_get_guard.assert_called_once_with("My Guard's Name") assert mock___call__.call_count == 1 - assert mock___call__.called_once_with( + mock___call__.assert_called_once_with( 1, 2, 3, @@ -387,20 +477,22 @@ def test_validate__call(mocker): prompt_params={"p1": "bar"}, num_reasks=0, some_kwarg="foo", + api_key="mock-key", ) - assert set_attribute_spy.call_count == 8 - expected_calls = [ - call("guardName", "My Guard's Name"), - call("prompt", "Hello world prompt!"), - call("instructions", "Hello world instructions!"), - call("validation_status", "fail"), - call("raw_llm_ouput", "Hello world!"), - call("validated_output", "None"), - call("tokens_consumed", None), - call("num_of_reasks", 0), - ] - set_attribute_spy.assert_has_calls(expected_calls) + # Temporarily Disabled + # assert set_attribute_spy.call_count == 8 + # expected_calls = [ + # call("guardName", "My Guard's Name"), + # call("prompt", "Hello world prompt!"), + # call("instructions", "Hello world instructions!"), + # call("validation_status", "fail"), + # call("raw_llm_ouput", "Hello world!"), + # call("validated_output", "None"), + # call("tokens_consumed", None), + # call("num_of_reasks", 0), + # ] + # set_attribute_spy.assert_has_calls(expected_calls) assert mock_cleanup_environment.call_count == 1 @@ -409,4 +501,7 @@ def test_validate__call(mocker): "validatedOutput": None, "sessionHistory": [{"history": []}], "rawLlmResponse": "Hello world!", + "validatedStream": [{"chunk": "Hello world!", "validation_errors": []}], } + + del os.environ["PGHOST"] diff --git a/tests/blueprints/test_root.py b/tests/blueprints/test_root.py index e0a85fe..3754748 100644 --- a/tests/blueprints/test_root.py +++ b/tests/blueprints/test_root.py @@ -1,3 +1,4 @@ +import os from src.utils.logger import logger from tests.mocks.mock_blueprint import MockBlueprint from tests.mocks.mock_postgres_client import MockPostgresClient @@ -17,6 +18,7 @@ def test_home(mocker): def test_health_check(mocker): + os.environ["PGHOST"] = "localhost" mocker.patch("flask.Blueprint", new=MockBlueprint) mock_pg = MockPostgresClient() @@ -34,10 +36,11 @@ def text_side_effect(query: str): response = health_check() - assert mock_text.called_once_with("SELECT count(datid) FROM pg_stat_activity;") + mock_text.assert_called_once_with("SELECT count(datid) FROM pg_stat_activity;") assert mock_pg.db.session.queries == ["SELECT count(datid) FROM pg_stat_activity;"] info_spy.assert_called_once_with("response: %s", [(1,)]) assert response == {"status": 200, "message": "Ok"} mocker.resetall() + del os.environ["PGHOST"] diff --git a/tests/clients/test_mem_guard_client.py b/tests/clients/test_mem_guard_client.py new file mode 100644 index 0000000..ed2bd77 --- /dev/null +++ b/tests/clients/test_mem_guard_client.py @@ -0,0 +1,51 @@ +# from src.clients.memory_guard_client import MemoryGuardClient +from tests.mocks.mock_guard_client import MockGuardStruct + + +def test_init(mocker): + from src.clients.memory_guard_client import MemoryGuardClient + + mem_guard_client = MemoryGuardClient() + + assert mem_guard_client.initialized is True + + +class TestGetGuard: + def test_get_all(self, mocker): + from src.clients.memory_guard_client import MemoryGuardClient + + guard_client = MemoryGuardClient() + + result = guard_client.get_guards() + + assert result == [] + + def test_get_all_after_insert(self, mocker): + from src.clients.memory_guard_client import MemoryGuardClient + + guard_client = MemoryGuardClient() + new_guard = MockGuardStruct() + guard_client.create_guard(new_guard) + result = guard_client.get_guards() + + assert result == [new_guard] + + def test_get_guard_after_insert(self, mocker): + from src.clients.memory_guard_client import MemoryGuardClient + + guard_client = MemoryGuardClient() + new_guard = MockGuardStruct("test_guard") + guard_client.create_guard(new_guard) + result = guard_client.get_guard("test_guard") + + assert result == new_guard + + def test_not_found(self, mocker): + from src.clients.memory_guard_client import MemoryGuardClient + + guard_client = MemoryGuardClient() + new_guard = MockGuardStruct("test_guard") + guard_client.create_guard(new_guard) + result = guard_client.get_guard("guard_that_does_not_exist") + + assert result is None diff --git a/tests/clients/test_guard_client.py b/tests/clients/test_pg_guard_client.py similarity index 72% rename from tests/clients/test_guard_client.py rename to tests/clients/test_pg_guard_client.py index 660027f..d220dd0 100644 --- a/tests/clients/test_guard_client.py +++ b/tests/clients/test_pg_guard_client.py @@ -1,9 +1,10 @@ import pytest from unittest.mock import ANY as AnyMatcher from src.classes.http_error import HttpError + +# from src.clients.memory_guard_client import MemoryGuardClient from src.models.guard_item import GuardItem from src.models.guard_item_audit import GuardItemAudit -from src.clients.guard_client import GuardClient from tests.mocks.mock_postgres_client import MockPostgresClient from tests.mocks.mock_guard_client import MockGuardStruct, MockRailspec from unittest.mock import call @@ -11,20 +12,25 @@ def test_init(mocker): mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) + mocker.patch( + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client + ) + + from src.clients.pg_guard_client import PGGuardClient - guard_client = GuardClient() + pg_guard_client = PGGuardClient() + # mem_guard_client = MemoryGuardClient() - assert guard_client.initialized is True - assert isinstance(guard_client.pgClient, MockPostgresClient) - assert guard_client.pgClient == mock_pg_client + assert pg_guard_client.initialized is True + assert isinstance(pg_guard_client.pgClient, MockPostgresClient) + assert pg_guard_client.pgClient == mock_pg_client class TestGetGuard: def test_get_latest(self, mocker): mock_pg_client = MockPostgresClient() mocker.patch( - "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client ) query_spy = mocker.spy(mock_pg_client.db.session, "query") @@ -34,25 +40,27 @@ def test_get_latest(self, mocker): mock_first.return_value = latest_guard mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) mock_from_guard_item.return_value = latest_guard - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() result = guard_client.get_guard("guard") query_spy.assert_called_once_with(GuardItem) filter_by_spy.assert_called_once_with(name="guard") assert mock_first.call_count == 1 - assert mock_from_guard_item.called_once_with(latest_guard) + mock_from_guard_item.assert_called_once_with(latest_guard) assert result == latest_guard def test_with_as_of_date(self, mocker): mock_pg_client = MockPostgresClient() mocker.patch( - "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client ) query_spy = mocker.spy(mock_pg_client.db.session, "query") @@ -65,11 +73,13 @@ def test_with_as_of_date(self, mocker): mock_first.side_effect = [latest_guard, previous_guard] mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) mock_from_guard_item.return_value = previous_guard - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() result = guard_client.get_guard("guard", as_of_date="2024-03-06") @@ -90,23 +100,25 @@ def test_with_as_of_date(self, mocker): assert replaced_on_order_exp.compare(order_by_spy_call) assert mock_first.call_count == 2 - assert mock_from_guard_item.called_once_with(previous_guard) + mock_from_guard_item.assert_called_once_with(previous_guard) assert result == previous_guard def test_raises_not_found(self, mocker): mock_pg_client = MockPostgresClient() mocker.patch( - "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client ) mock_first = mocker.patch.object(mock_pg_client.db.session, "first") mock_first.return_value = None mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() with pytest.raises(HttpError) as exc_info: guard_client.get_guard("guard") @@ -122,7 +134,9 @@ def test_raises_not_found(self, mocker): def test_get_guard_item(mocker): mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) + mocker.patch( + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client + ) query_spy = mocker.spy(mock_pg_client.db.session, "query") filter_by_spy = mocker.spy(mock_pg_client.db.session, "filter_by") @@ -130,7 +144,9 @@ def test_get_guard_item(mocker): latest_guard = MockGuardStruct("latest") mock_first.return_value = latest_guard - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() result = guard_client.get_guard_item("guard") @@ -143,7 +159,9 @@ def test_get_guard_item(mocker): def test_get_guards(mocker): mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) + mocker.patch( + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client + ) query_spy = mocker.spy(mock_pg_client.db.session, "query") mock_all = mocker.patch.object(mock_pg_client.db.session, "all") @@ -153,11 +171,13 @@ def test_get_guards(mocker): mock_all.return_value = guards mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) mock_from_guard_item.side_effect = [guard_one, guard_two] - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() result = guard_client.get_guards() @@ -175,18 +195,22 @@ def test_create_guard(mocker): mock_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() mock_guard_struct_init_spy = mocker.spy(MockGuardStruct, "__init__") - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - mocker.patch("src.clients.guard_client.GuardItem", new=MockGuardStruct) + mocker.patch( + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client + ) + mocker.patch("src.clients.pg_guard_client.GuardItem", new=MockGuardStruct) add_spy = mocker.spy(mock_pg_client.db.session, "add") commit_spy = mocker.spy(mock_pg_client.db.session, "commit") mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) mock_from_guard_item.return_value = mock_guard - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() result = guard_client.create_guard(mock_guard) @@ -215,19 +239,21 @@ def test_raises_not_found(self, mocker): mock_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() mocker.patch( - "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client ) mock_get_guard_item = mocker.patch( - "src.clients.guard_client.GuardClient.get_guard_item" + "src.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = None commit_spy = mocker.spy(mock_pg_client.db.session, "commit") mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() with pytest.raises(HttpError) as exc_info: guard_client.update_guard("mock-guard", mock_guard) @@ -247,21 +273,23 @@ def test_updates_guard_item(self, mocker): updated_guard = MockGuardStruct(num_reasks=2) mock_pg_client = MockPostgresClient() mocker.patch( - "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client ) mock_get_guard_item = mocker.patch( - "src.clients.guard_client.GuardClient.get_guard_item" + "src.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = old_guard to_dict_spy = mocker.spy(updated_guard.railspec, "to_dict") commit_spy = mocker.spy(mock_pg_client.db.session, "commit") mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) mock_from_guard_item.return_value = updated_guard - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() result = guard_client.update_guard("mock-guard", updated_guard) @@ -283,30 +311,32 @@ def test_guard_doesnt_exist_yet(self, mocker): new_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() mocker.patch( - "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client ) mock_get_guard_item = mocker.patch( - "src.clients.guard_client.GuardClient.get_guard_item" + "src.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = None commit_spy = mocker.spy(mock_pg_client.db.session, "commit") mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) mock_create_guard = mocker.patch( - "src.clients.guard_client.GuardClient.create_guard" + "src.clients.pg_guard_client.PGGuardClient.create_guard" ) mock_create_guard.return_value = new_guard - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() result = guard_client.upsert_guard("mock-guard", input_guard) - assert mock_get_guard_item.called_once_with("mock-guard") + mock_get_guard_item.assert_called_once_with("mock-guard") assert commit_spy.call_count == 0 assert mock_from_guard_item.call_count == 0 - assert mock_create_guard.called_once_with(input_guard) + mock_create_guard.assert_called_once_with(input_guard) assert result == new_guard @@ -315,21 +345,23 @@ def test_guard_already_exists(self, mocker): updated_guard = MockGuardStruct(num_reasks=2, description="updated description") mock_pg_client = MockPostgresClient() mocker.patch( - "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client ) mock_get_guard_item = mocker.patch( - "src.clients.guard_client.GuardClient.get_guard_item" + "src.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = old_guard to_dict_spy = mocker.spy(updated_guard.railspec, "to_dict") commit_spy = mocker.spy(mock_pg_client.db.session, "commit") mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) mock_from_guard_item.return_value = updated_guard - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() result = guard_client.upsert_guard("mock-guard", updated_guard) @@ -349,19 +381,21 @@ class TestDeleteGuard: def test_raises_not_found(self, mocker): mock_pg_client = MockPostgresClient() mocker.patch( - "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client ) mock_get_guard_item = mocker.patch( - "src.clients.guard_client.GuardClient.get_guard_item" + "src.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = None commit_spy = mocker.spy(mock_pg_client.db.session, "commit") mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() with pytest.raises(HttpError) as exc_info: guard_client.delete_guard("mock-guard") @@ -380,21 +414,23 @@ def test_deletes_guard_item(self, mocker): old_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() mocker.patch( - "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + "src.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client ) mock_get_guard_item = mocker.patch( - "src.clients.guard_client.GuardClient.get_guard_item" + "src.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = old_guard delete_spy = mocker.spy(mock_pg_client.db.session, "delete") commit_spy = mocker.spy(mock_pg_client.db.session, "commit") mock_from_guard_item = mocker.patch( - "src.clients.guard_client.GuardStruct.from_guard_item" + "src.clients.pg_guard_client.GuardStruct.from_guard_item" ) mock_from_guard_item.return_value = old_guard - guard_client = GuardClient() + from src.clients.pg_guard_client import PGGuardClient + + guard_client = PGGuardClient() result = guard_client.delete_guard("mock-guard") diff --git a/tests/mocks/mock_guard_client.py b/tests/mocks/mock_guard_client.py index 5ffe424..214146e 100644 --- a/tests/mocks/mock_guard_client.py +++ b/tests/mocks/mock_guard_client.py @@ -1,9 +1,12 @@ +from src.classes.guard_struct import GuardStruct + + class MockRailspec: def to_dict(self, *args, **kwargs): return {} -class MockGuardStruct: +class MockGuardStruct(GuardStruct): name: str description: str num_reasks: int