diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index f6ed06922..45891ea37 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -9,6 +9,7 @@ on:
     branches:
       - main
       - dev
+      - feat/*

   # Allows you to run this workflow manually from the Actions tab
   workflow_dispatch:
@@ -49,8 +50,8 @@ jobs:
     strategy:
       matrix:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
-        pydantic-version: ["1.10.9", "2.4.2"]
-        openai-version: ["1.30.1"]
+        pydantic-version: ["==1.10.9", ">=2.x"]
+        openai-version: [">=1.x"]
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
@@ -71,8 +72,8 @@ jobs:
         # TODO: fix errors so that we can run `make dev` instead
         run: |
           make full
-          poetry run pip install pydantic==${{ matrix.pydantic-version }}
-          poetry run pip install openai==${{ matrix.openai-version }}
+          poetry run pip install "pydantic${{ matrix.pydantic-version }}"
+          poetry run pip install "openai${{ matrix.openai-version }}"

       - name: Static analysis with pyright
         run: |
@@ -86,8 +87,8 @@ jobs:
         # TODO: fix errors so that we can run both `make dev` and `make full`
         # dependencies: ['dev', 'full']
         dependencies: ["full"]
-        pydantic-version: ["1.10.9", "2.4.2"]
-        openai-version: ["1.30.1"]
+        pydantic-version: ["==1.10.9", ">=2.x"]
+        openai-version: [">=1.x"]

     steps:
       - uses: actions/checkout@v4
@@ -113,8 +114,8 @@ jobs:
       - name: Install Dependencies
         run: |
           make ${{ matrix.dependencies }}
-          poetry run pip install pydantic==${{ matrix.pydantic-version }}
-          poetry run pip install openai==${{ matrix.openai-version }}
+          poetry run pip install "pydantic${{ matrix.pydantic-version }}"
+          poetry run pip install "openai${{ matrix.openai-version }}"

       - name: Run Pytests
         run: |
diff --git a/.github/workflows/examples_check.yml b/.github/workflows/examples_check.yml
index 9be015077..a7eab18b0 100644
--- a/.github/workflows/examples_check.yml
+++ b/.github/workflows/examples_check.yml
@@ -15,7 +15,7 @@ jobs:
     strategy:
       matrix:
         # this line is automatically generated by the script in .github/workflows/scripts/update_notebook_matrix.sh
-        notebook: ["bug_free_python_code.ipynb","check_for_pii.ipynb","competitors_check.ipynb","extracting_entities.ipynb","generate_structured_data.ipynb","generate_structured_data_cohere.ipynb","guardrails_with_chat_models.ipynb","input_validation.ipynb","llamaindex-output-parsing.ipynb","no_secrets_in_generated_text.ipynb","provenance.ipynb","recipe_generation.ipynb","regex_validation.ipynb","response_is_on_topic.ipynb","secrets_detection.ipynb","select_choice_based_on_action.ipynb","streaming.ipynb","syntax_error_free_sql.ipynb","text_summarization_quality.ipynb","toxic_language.ipynb","translation_to_specific_language.ipynb","translation_with_quality_check.ipynb","valid_chess_moves.ipynb","value_within_distribution.ipynb"]
+        notebook: ["bug_free_python_code.ipynb","check_for_pii.ipynb","competitors_check.ipynb","extracting_entities.ipynb","generate_structured_data.ipynb","generate_structured_data_cohere.ipynb","guardrails_with_chat_models.ipynb","input_validation.ipynb","llamaindex-output-parsing.ipynb","no_secrets_in_generated_text.ipynb","provenance.ipynb","recipe_generation.ipynb","regex_validation.ipynb","response_is_on_topic.ipynb","secrets_detection.ipynb","select_choice_based_on_action.ipynb","syntax_error_free_sql.ipynb","text_summarization_quality.ipynb","toxic_language.ipynb","translation_to_specific_language.ipynb","translation_with_quality_check.ipynb","valid_chess_moves.ipynb","value_within_distribution.ipynb"]
     env:
       COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
       OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -24,6 +24,9 @@ jobs:
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
+        with:
+          ref: ${{ github.head_ref }}
+          fetch-depth: 0
       - name: Create .guardrailsrc
         run: |
           echo 'id="SYSTEM TESTING"' > ~/.guardrailsrc
@@ -32,31 +35,30 @@
         uses: actions/setup-python@v5
         with:
           python-version: 3.11.x
-      # - name: Poetry cache
-      #   uses: actions/cache@v3
-      #   with:
-      #     path: ~/.cache/pypoetry
-      #     key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}
-      - name: Install Poetry
-        uses: snok/install-poetry@v1
-        with:
-          virtualenvs-create: true
-          virtualenvs-in-project: true
-          installer-parallel: true
       - name: Install dependencies
         run: |
-          make full;
-          poetry add "openai>=1.2.4" jupyter nbconvert cohere==5.3.2;
-      - name: Check for pypdfium2
-        run: poetry run pip show pypdfium2
+
+          # Setup Virtual Environment
+          python3 -m venv ./.venv
+          source .venv/bin/activate
+
+          # Install the current branch
+          pip install .
+
+          # Install extra stuff for notebook runs
+          pip install "huggingface_hub[cli]" jupyter nbconvert cohere==5.3.2
+          pip install nltk
       - name: Huggingface Hub Login
-        run: poetry run huggingface-cli login --token $HUGGINGFACE_API_KEY
+        run: |
+          source .venv/bin/activate
+          huggingface-cli login --token $HUGGINGFACE_API_KEY
       - name: download nltk data
         run: |
+          source .venv/bin/activate
           mkdir /tmp/nltk_data;
-          poetry run python -m nltk.downloader -d /tmp/nltk_data punkt;
-      - name: Use venv
-        run: source .venv/bin/activate
+          python -m nltk.downloader -d /tmp/nltk_data punkt;
       - name: Execute notebooks and check for errors
-        run: bash ./.github/workflows/scripts/run_notebooks.sh ${{ matrix.notebook }}
+        run: |
+          source .venv/bin/activate
+          bash ./.github/workflows/scripts/run_notebooks.sh ${{ matrix.notebook }}
diff --git a/.github/workflows/install_from_hub.yml b/.github/workflows/install_from_hub.yml
index c39cc8502..1e6483547 100644
--- a/.github/workflows/install_from_hub.yml
+++ b/.github/workflows/install_from_hub.yml
@@ -1,4 +1,4 @@
-name: Notebook Execution and Error Check
+name: Install from Hub

 on:
   push:
diff --git a/.github/workflows/scripts/run_notebooks.sh b/.github/workflows/scripts/run_notebooks.sh
index e32932948..cd854ceb6 100755
--- a/.github/workflows/scripts/run_notebooks.sh
+++ b/.github/workflows/scripts/run_notebooks.sh
@@ -1,6 +1,10 @@
 #!/bin/bash
 export NLTK_DATA=/tmp/nltk_data;

+# Remove the local guardrails directory and use the installed version
+rm -rf ./guardrails
+
+# Navigate to notebooks
 cd docs/examples

 # Get the notebook name from the matrix variable
@@ -10,10 +14,12 @@ notebook="$1"

 invalid_notebooks=("valid_chess_moves.ipynb" "llamaindex-output-parsing.ipynb" "competitors_check.ipynb")
 if [[ ! " ${invalid_notebooks[@]} " =~ " ${notebook} " ]]; then
   echo "Processing $notebook..."
-  poetry run jupyter nbconvert --to notebook --execute "$notebook"
+  # poetry run jupyter nbconvert --to notebook --execute "$notebook"
+  jupyter nbconvert --to notebook --execute "$notebook"
   if [ $? -ne 0 ]; then
     echo "Error found in $notebook"
     echo "Error in $notebook. See logs for details." >> errors.txt
+    exit 1
   fi
 fi
diff --git a/Makefile b/Makefile
index ee919a60b..59f3d0fc5 100644
--- a/Makefile
+++ b/Makefile
@@ -84,4 +84,4 @@ precommit:
 # pytest -x -q --no-summary
 	pyright guardrails/
 	make lint
-	./github/workflows/scripts/update_notebook_matrix.sh
+	./.github/workflows/scripts/update_notebook_matrix.sh
diff --git a/docs/examples/bug_free_python_code.ipynb b/docs/examples/bug_free_python_code.ipynb
index 340ca3bda..d377b4eb5 100644
--- a/docs/examples/bug_free_python_code.ipynb
+++ b/docs/examples/bug_free_python_code.ipynb
@@ -30,15 +30,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Note: you may need to restart the kernel to use updated packages.\n",
-      "Note: you may need to restart the kernel to use updated packages.\n",
       "Installing hub:\u001b[35m/\u001b[0m\u001b[35m/reflex/\u001b[0m\u001b[95mvalid_python...\u001b[0m\n",
       "✅Successfully installed reflex/valid_python!\n",
       "\n",
@@ -47,15 +45,12 @@
     }
    ],
    "source": [
-    "%pip install guardrails-ai -q\n",
-    "%pip install pydantic -q\n",
-    "\n",
     "!guardrails hub install hub://reflex/valid_python --quiet"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -93,7 +88,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -111,7 +106,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -128,7 +123,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
@@ -209,7 +204,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -239,7 +234,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
@@ -284,7 +279,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [
     {
@@ -356,7 +351,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
@@ -398,7 +393,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.0"
+   "version": "3.11.9"
   },
   "orig_nbformat": 4,
   "vscode": {
diff --git a/docs/examples/generate_structured_data_cohere.ipynb b/docs/examples/generate_structured_data_cohere.ipynb
index 35286463b..e22624ea8 100644
--- a/docs/examples/generate_structured_data_cohere.ipynb
+++ b/docs/examples/generate_structured_data_cohere.ipynb
@@ -10,14 +10,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 1,
    "id": "346e1b5c",
    "metadata": {},
    "outputs": [],
    "source": [
     "prompt = \"\"\"\n",
     "Generate a dataset of fake user orders. Each row of the dataset should be valid. 
The format should not be a list, it should be a JSON object.\n", - "${gr.complete_json_suffix}\n", + "${gr.complete_xml_suffix}\n", "\n", "an example of output may look like this:\n", "{\n", @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "6a7c7d4a", "metadata": {}, "outputs": [ @@ -49,103 +49,26 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m2024-03-25 16:13:45\u001b[0m \u001b[35mzmac\u001b[0m \u001b[34mguardrails-cli[96111]\u001b[0m \u001b[1;30mNOTICE\u001b[0m \u001b[1;36mInstalling hub://guardrails/valid_length...\u001b[0m\n", - " Running command git clone --filter=blob:none --quiet https://github.com/guardrails-ai/valid_length.git /private/var/folders/c8/jqt82fpx785dpwpp36ljkgm40000gn/T/pip-req-build-y3acxka5\n", - "\u001b[33mWARNING: Target directory /Users/zaydsimjee/workspace/guardrails/.venv/lib/python3.11/site-packages/guardrails/hub/guardrails/valid_length/validator already exists. Specify --upgrade to force replacement.\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Target directory /Users/zaydsimjee/workspace/guardrails/.venv/lib/python3.11/site-packages/guardrails/hub/guardrails/valid_length/valid_length-0.0.0.dist-info already exists. Specify --upgrade to force replacement.\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[32m2024-03-25 16:13:52\u001b[0m \u001b[35mzmac\u001b[0m \u001b[34mguardrails-cli[96111]\u001b[0m \u001b[1;30mINFO\u001b[0m Collecting git+https://github.com/guardrails-ai/valid_length.git\n", - " Cloning https://github.com/guardrails-ai/valid_length.git to /private/var/folders/c8/jqt82fpx785dpwpp36ljkgm40000gn/T/pip-req-build-y3acxka5\n", - " Resolved https://github.com/guardrails-ai/valid_length.git to commit 4b59a5fc1ae2106a585881784e3c2086f1fe8b9b\n", - " Installing build dependencies: started\n", - " Installing build dependencies: finished with status 'done'\n", - " Getting requirements to build wheel: started\n", - " Getting requirements to build wheel: finished with status 'done'\n", - " Installing backend dependencies: started\n", - " Installing backend dependencies: finished with status 'done'\n", - " Preparing metadata (pyproject.toml): started\n", - " Preparing metadata (pyproject.toml): finished with status 'done'\n", - "Building wheels for collected packages: valid_length\n", - " Building wheel for valid_length (pyproject.toml): started\n", - " Building wheel for valid_length (pyproject.toml): finished with status 'done'\n", - " Created wheel for valid_length: filename=valid_length-0.0.0-py3-none-any.whl size=12348 sha256=98e297c72fa6bc34b9c52e2ce3b87365ce925225b0792d1a72096baf11b4e792\n", - " Stored in directory: /private/var/folders/c8/jqt82fpx785dpwpp36ljkgm40000gn/T/pip-ephem-wheel-cache-odhptidc/wheels/48/9f/75/34a76a1e575dafaf9df180a2074f698d77193d5d3670823f69\n", - "Successfully built valid_length\n", - "Installing collected packages: valid_length\n", - "Successfully installed valid_length-0.0.0\n", - "\n", - "\u001b[32m2024-03-25 16:13:52\u001b[0m \u001b[35mzmac\u001b[0m \u001b[34mguardrails-cli[96111]\u001b[0m \u001b[1;30mSUCCESS\u001b[0m \u001b[1;32m\n", - "\n", - " Successfully installed guardrails/valid_length!\n", - "\n", - " See how to use it here: https://hub.guardrailsai.com/validator/guardrails/valid_length\n", - " \u001b[0m\n", - "\u001b[32m2024-03-25 16:13:53\u001b[0m \u001b[35mzmac\u001b[0m \u001b[34mguardrails-cli[96229]\u001b[0m \u001b[1;30mNOTICE\u001b[0m \u001b[1;36mInstalling hub://guardrails/two_words...\u001b[0m\n", - " Running command git clone 
--filter=blob:none --quiet https://github.com/guardrails-ai/two_words.git /private/var/folders/c8/jqt82fpx785dpwpp36ljkgm40000gn/T/pip-req-build-c9bdjnpk\n", - "\u001b[33mWARNING: Target directory /Users/zaydsimjee/workspace/guardrails/.venv/lib/python3.11/site-packages/guardrails/hub/guardrails/two_words/validator already exists. Specify --upgrade to force replacement.\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Target directory /Users/zaydsimjee/workspace/guardrails/.venv/lib/python3.11/site-packages/guardrails/hub/guardrails/two_words/two_words-0.0.0.dist-info already exists. Specify --upgrade to force replacement.\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[32m2024-03-25 16:13:59\u001b[0m \u001b[35mzmac\u001b[0m \u001b[34mguardrails-cli[96229]\u001b[0m \u001b[1;30mINFO\u001b[0m Collecting git+https://github.com/guardrails-ai/two_words.git\n", - " Cloning https://github.com/guardrails-ai/two_words.git to /private/var/folders/c8/jqt82fpx785dpwpp36ljkgm40000gn/T/pip-req-build-c9bdjnpk\n", - " Resolved https://github.com/guardrails-ai/two_words.git to commit e7f682c0b8d45a9407e966028d72682cd909601e\n", - " Installing build dependencies: started\n", - " Installing build dependencies: finished with status 'done'\n", - " Getting requirements to build wheel: started\n", - " Getting requirements to build wheel: finished with status 'done'\n", - " Installing backend dependencies: started\n", - " Installing backend dependencies: finished with status 'done'\n", - " Preparing metadata (pyproject.toml): started\n", - " Preparing metadata (pyproject.toml): finished with status 'done'\n", - "Building wheels for collected packages: two_words\n", - " Building wheel for two_words (pyproject.toml): started\n", - " Building wheel for two_words (pyproject.toml): finished with status 'done'\n", - " Created wheel for two_words: filename=two_words-0.0.0-py3-none-any.whl size=11227 sha256=a10ae6f93738a3223ec28db3712fdc288547689235606516c589caff1f84889c\n", - " Stored in directory: /private/var/folders/c8/jqt82fpx785dpwpp36ljkgm40000gn/T/pip-ephem-wheel-cache-b_l3rvkp/wheels/36/68/76/f184dbc7d9cea0daec56ec1394537018f2ddeb660f9ad79ce6\n", - "Successfully built two_words\n", - "Installing collected packages: two_words\n", - "Successfully installed two_words-0.0.0\n", - "\n", - "\u001b[32m2024-03-25 16:14:00\u001b[0m \u001b[35mzmac\u001b[0m \u001b[34mguardrails-cli[96229]\u001b[0m \u001b[1;30mSUCCESS\u001b[0m \u001b[1;32m\n", + "Installing hub:\u001b[35m/\u001b[0m\u001b[35m/guardrails/\u001b[0m\u001b[95mvalid_length...\u001b[0m\n", + "✅Successfully installed guardrails/valid_length!\n", "\n", - " Successfully installed guardrails/two_words!\n", "\n", - " See how to use it here: https://hub.guardrailsai.com/validator/guardrails/two_words\n", - " \u001b[0m\n", - "\u001b[32m2024-03-25 16:14:01\u001b[0m \u001b[35mzmac\u001b[0m \u001b[34mguardrails-cli[96313]\u001b[0m \u001b[1;30mNOTICE\u001b[0m \u001b[1;36mInstalling hub://guardrails/valid_range...\u001b[0m\n", - " Running command git clone --filter=blob:none --quiet https://github.com/guardrails-ai/valid_range.git /private/var/folders/c8/jqt82fpx785dpwpp36ljkgm40000gn/T/pip-req-build-_ndib8nc\n", - "\u001b[33mWARNING: Target directory /Users/zaydsimjee/workspace/guardrails/.venv/lib/python3.11/site-packages/guardrails/hub/guardrails/valid_range/validator already exists. 
Specify --upgrade to force replacement.\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Target directory /Users/zaydsimjee/workspace/guardrails/.venv/lib/python3.11/site-packages/guardrails/hub/guardrails/valid_range/valid_range-0.0.0.dist-info already exists. Specify --upgrade to force replacement.\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[32m2024-03-25 16:14:07\u001b[0m \u001b[35mzmac\u001b[0m \u001b[34mguardrails-cli[96313]\u001b[0m \u001b[1;30mINFO\u001b[0m Collecting git+https://github.com/guardrails-ai/valid_range.git\n", - " Cloning https://github.com/guardrails-ai/valid_range.git to /private/var/folders/c8/jqt82fpx785dpwpp36ljkgm40000gn/T/pip-req-build-_ndib8nc\n", - " Resolved https://github.com/guardrails-ai/valid_range.git to commit d01ad21d73d753ad224fd395bce18428196951c5\n", - " Installing build dependencies: started\n", - " Installing build dependencies: finished with status 'done'\n", - " Getting requirements to build wheel: started\n", - " Getting requirements to build wheel: finished with status 'done'\n", - " Installing backend dependencies: started\n", - " Installing backend dependencies: finished with status 'done'\n", - " Preparing metadata (pyproject.toml): started\n", - " Preparing metadata (pyproject.toml): finished with status 'done'\n", - "Building wheels for collected packages: valid_range\n", - " Building wheel for valid_range (pyproject.toml): started\n", - " Building wheel for valid_range (pyproject.toml): finished with status 'done'\n", - " Created wheel for valid_range: filename=valid_range-0.0.0-py3-none-any.whl size=11575 sha256=ca6ffe0537e7a64c74332398596451c01395a6fa8086f98e514d0a1cfad5fff9\n", - " Stored in directory: /private/var/folders/c8/jqt82fpx785dpwpp36ljkgm40000gn/T/pip-ephem-wheel-cache-7zrcqsov/wheels/0c/66/c0/f9ea25da535775c4ffca5bbd385863945a6397fd1863f6abe8\n", - "Successfully built valid_range\n", - "Installing collected packages: valid_range\n", - "Successfully installed valid_range-0.0.0\n", + "Installing hub:\u001b[35m/\u001b[0m\u001b[35m/guardrails/\u001b[0m\u001b[95mtwo_words...\u001b[0m\n", + "✅Successfully installed guardrails/two_words!\n", "\n", - "\u001b[32m2024-03-25 16:14:07\u001b[0m \u001b[35mzmac\u001b[0m \u001b[34mguardrails-cli[96313]\u001b[0m \u001b[1;30mSUCCESS\u001b[0m \u001b[1;32m\n", "\n", - " Successfully installed guardrails/valid_range!\n", + "Installing hub:\u001b[35m/\u001b[0m\u001b[35m/guardrails/\u001b[0m\u001b[95mvalid_range...\u001b[0m\n", + "✅Successfully installed guardrails/valid_range!\n", "\n", - " See how to use it here: https://hub.guardrailsai.com/validator/guardrails/valid_range\n", - " \u001b[0m\n" + "\n" ] } ], "source": [ - "!guardrails hub install hub://guardrails/valid_length\n", - "!guardrails hub install hub://guardrails/two_words\n", - "!guardrails hub install hub://guardrails/valid_range" + "!guardrails hub install hub://guardrails/valid_length --quiet\n", + "!guardrails hub install hub://guardrails/two_words --quiet\n", + "!guardrails hub install hub://guardrails/valid_range --quiet\n", + "!pip install cohere==5.3.2 --quiet" ] }, { @@ -158,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "3088fd99", "metadata": {}, "outputs": [], @@ -198,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "840006ca-21ca-4f76-9ce1-e406d5d68412", "metadata": {}, "outputs": [], @@ -221,7 +144,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "42766922-14d0-4b5e-853a-23f05b896a09", 
"metadata": {}, "outputs": [ @@ -229,22 +152,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/zaydsimjee/workspace/guardrails/.venv/lib/python3.11/site-packages/guardrails/validatorsattr.py:307: UserWarning: Validator 1-indexed is not installed!\n", - " warnings.warn(f\"Validator {validator_name} is not installed!\")\n", - "/Users/zaydsimjee/workspace/guardrails/.venv/lib/python3.11/site-packages/guardrails/validators/__init__.py:50: FutureWarning: \n", - " Importing validators from `guardrails.validators` is deprecated.\n", - " All validators are now available in the Guardrails Hub. Please install\n", - " and import them from the hub instead. All validators will be\n", - " removed from this module in the next major release.\n", - "\n", - " Install with: `guardrails hub install hub:///`\n", - " Import as: from guardrails.hub import `ValidatorName`\n", - " \n", - " warn(\n", - "\n", - "HTTP Request: POST https://api.cohere.ai/v1/chat \"HTTP/1.1 200 OK\"\n", - "Diffusion not supported. Skipping import.\n", - "HTTP Request: POST https://api.cohere.ai/v1/chat \"HTTP/1.1 200 OK\"\n" + "/Users/calebcourier/Projects/gr-mono/guardrails/docs/examples/.venv/lib/python3.11/site-packages/guardrails/validatorsattr.py:307: UserWarning: Validator 1-indexed is not installed!\n", + " warnings.warn(f\"Validator {validator_name} is not installed!\")\n" ] } ], @@ -272,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "id": "0e910d87", "metadata": {}, "outputs": [ @@ -343,47 +252,47 @@ " │ │ { │ │\n", " │ │ \"user_id\": 2, │ │\n", " │ │ \"user_name\": \"Jane Smith\", │ │\n", - " │ │ \"num_orders\": 10 │ │\n", + " │ │ \"num_orders\": 4 │ │\n", " │ │ }, │ │\n", " │ │ { │ │\n", " │ │ \"user_id\": 3, │ │\n", " │ │ \"user_name\": \"David Lee\", │ │\n", - " │ │ \"num_orders\": 4 │ │\n", + " │ │ \"num_orders\": 2 │ │\n", " │ │ }, │ │\n", " │ │ { │ │\n", " │ │ \"user_id\": 4, │ │\n", " │ │ \"user_name\": \"Rachelle Gonzalez\", │ │\n", - " │ │ \"num_orders\": 2 │ │\n", + " │ │ \"num_orders\": 1 │ │\n", " │ │ }, │ │\n", " │ │ { │ │\n", " │ │ \"user_id\": 5, │ │\n", - " │ │ \"user_name\": \"Peter Brown\", │ │\n", + " │ │ \"user_name\": \"Frank Anderson\", │ │\n", " │ │ \"num_orders\": 3 │ │\n", " │ │ }, │ │\n", " │ │ { │ │\n", " │ │ \"user_id\": 6, │ │\n", - " │ │ \"user_name\": \"Micheal Wilson\", │ │\n", + " │ │ \"user_name\": \"Lisa Taylor\", │ │\n", " │ │ \"num_orders\": 5 │ │\n", " │ │ }, │ │\n", " │ │ { │ │\n", " │ │ \"user_id\": 7, │ │\n", - " │ │ \"user_name\": \"Sarah Jones\", │ │\n", - " │ │ \"num_orders\": 0 │ │\n", + " │ │ \"user_name\": \"Peter Wilson\", │ │\n", + " │ │ \"num_orders\": 7 │ │\n", " │ │ }, │ │\n", " │ │ { │ │\n", " │ │ \"user_id\": 8, │ │\n", - " │ │ \"user_name\": \"Rachelle Perez\", │ │\n", - " │ │ \"num_orders\": 8 │ │\n", + " │ │ \"user_name\": \"Micheal Harris\", │ │\n", + " │ │ \"num_orders\": 4 │ │\n", " │ │ }, │ │\n", " │ │ { │ │\n", " │ │ \"user_id\": 9, │ │\n", - " │ │ \"user_name\": \"John Garcia\", │ │\n", - " │ │ \"num_orders\": 1 │ │\n", + " │ │ \"user_name\": \"Sarah Anderson\", │ │\n", + " │ │ \"num_orders\": 2 │ │\n", " │ │ }, │ │\n", " │ │ { │ │\n", " │ │ \"user_id\": 10, │ │\n", - " │ │ \"user_name\": \"Jane Martinez\", │ │\n", - " │ │ \"num_orders\": 7 │ │\n", + " │ │ \"user_name\": \"Jessica Taylor\", │ │\n", + " │ │ \"num_orders\": 1 │ │\n", " │ │ } │ │\n", " │ │ ] │ │\n", " │ │ } │ │\n", @@ -392,15 +301,15 @@ " │ │ { │ │\n", " │ │ 'user_orders': [ │ │\n", " │ │ {'user_id': 1, 'user_name': 'John Mcdonald', 
'num_orders': 6}, │ │\n", - " │ │ {'user_id': 2, 'user_name': 'Jane Smith', 'num_orders': 10}, │ │\n", - " │ │ {'user_id': 3, 'user_name': 'David Lee', 'num_orders': 4}, │ │\n", - " │ │ {'user_id': 4, 'user_name': 'Rachelle Gonzalez', 'num_orders': 2}, │ │\n", - " │ │ {'user_id': 5, 'user_name': 'Peter Brown', 'num_orders': 3}, │ │\n", - " │ │ {'user_id': 6, 'user_name': 'Micheal Wilson', 'num_orders': 5}, │ │\n", - " │ │ {'user_id': 7, 'user_name': 'Sarah Jones', 'num_orders': 0}, │ │\n", - " │ │ {'user_id': 8, 'user_name': 'Rachelle Perez', 'num_orders': 8}, │ │\n", - " │ │ {'user_id': 9, 'user_name': 'John Garcia', 'num_orders': 1}, │ │\n", - " │ │ {'user_id': 10, 'user_name': 'Jane Martinez', 'num_orders': 7} │ │\n", + " │ │ {'user_id': 2, 'user_name': 'Jane Smith', 'num_orders': 4}, │ │\n", + " │ │ {'user_id': 3, 'user_name': 'David Lee', 'num_orders': 2}, │ │\n", + " │ │ {'user_id': 4, 'user_name': 'Rachelle Gonzalez', 'num_orders': 1}, │ │\n", + " │ │ {'user_id': 5, 'user_name': 'Frank Anderson', 'num_orders': 3}, │ │\n", + " │ │ {'user_id': 6, 'user_name': 'Lisa Taylor', 'num_orders': 5}, │ │\n", + " │ │ {'user_id': 7, 'user_name': 'Peter Wilson', 'num_orders': 7}, │ │\n", + " │ │ {'user_id': 8, 'user_name': 'Micheal Harris', 'num_orders': 4}, │ │\n", + " │ │ {'user_id': 9, 'user_name': 'Sarah Anderson', 'num_orders': 2}, │ │\n", + " │ │ {'user_id': 10, 'user_name': 'Jessica Taylor', 'num_orders': 1} │ │\n", " │ │ ] │ │\n", " │ │ } │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", @@ -472,47 +381,47 @@ " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m {\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_id\": 2,\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Jane Smith\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 10\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 4\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m },\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m {\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_id\": 3,\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"David 
Lee\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 4\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 2\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m },\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m {\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_id\": 4,\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Rachelle Gonzalez\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 2\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 1\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m },\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m {\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_id\": 5,\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Peter Brown\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Frank Anderson\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 3\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m },\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m 
\u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m {\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_id\": 6,\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Micheal Wilson\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Lisa Taylor\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 5\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m },\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m {\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_id\": 7,\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Sarah Jones\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 0\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Peter Wilson\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 7\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m },\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m {\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_id\": 8,\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ 
\u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Rachelle Perez\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 8\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Micheal Harris\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 4\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m },\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m {\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_id\": 9,\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"John Garcia\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 1\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Sarah Anderson\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 2\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m },\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m {\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_id\": 10,\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Jane Martinez\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ 
\u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 7\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"user_name\": \"Jessica Taylor\",\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"num_orders\": 1\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m }\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m ]\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m}\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", @@ -521,15 +430,15 @@ " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m{\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m 'user_orders': [\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 1, 'user_name': 'John Mcdonald', 'num_orders': 6},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 2, 'user_name': 'Jane Smith', 'num_orders': 10},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 3, 'user_name': 'David Lee', 'num_orders': 4},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 4, 'user_name': 'Rachelle Gonzalez', 'num_orders': 2},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 5, 'user_name': 'Peter Brown', 'num_orders': 3},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 6, 'user_name': 'Micheal Wilson', 'num_orders': 5},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ 
\u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 7, 'user_name': 'Sarah Jones', 'num_orders': 0},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 8, 'user_name': 'Rachelle Perez', 'num_orders': 8},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 9, 'user_name': 'John Garcia', 'num_orders': 1},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 10, 'user_name': 'Jane Martinez', 'num_orders': 7}\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 2, 'user_name': 'Jane Smith', 'num_orders': 4},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 3, 'user_name': 'David Lee', 'num_orders': 2},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 4, 'user_name': 'Rachelle Gonzalez', 'num_orders': 1},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 5, 'user_name': 'Frank Anderson', 'num_orders': 3},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 6, 'user_name': 'Lisa Taylor', 'num_orders': 5},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 7, 'user_name': 'Peter Wilson', 'num_orders': 7},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 8, 'user_name': 'Micheal Harris', 'num_orders': 4},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 9, 'user_name': 'Sarah Anderson', 'num_orders': 2},\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m {'user_id': 10, 'user_name': 'Jessica Taylor', 'num_orders': 1}\u001b[0m\u001b[48;2;240;255;240m 
\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m ]\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m}\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", " │ \u001b[48;2;240;255;240m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", @@ -563,7 +472,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/docs/examples/no_secrets_in_generated_text.ipynb b/docs/examples/no_secrets_in_generated_text.ipynb index f90372c0e..05f0ecc16 100644 --- a/docs/examples/no_secrets_in_generated_text.ipynb +++ b/docs/examples/no_secrets_in_generated_text.ipynb @@ -50,8 +50,18 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/zayd/workspace/guardrails/.venv/lib/python3.9/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML\n", - " warnings.warn(\"Can't initialize NVML\")\n" + "/Users/calebcourier/Projects/gr-mono/guardrails/docs/examples/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/Users/calebcourier/Projects/gr-mono/guardrails/docs/examples/.venv/lib/python3.11/site-packages/guardrails/validators/__init__.py:51: FutureWarning: \n", + " Importing validators from `guardrails.validators` is deprecated.\n", + " All validators are now available in the Guardrails Hub. Please install\n", + " and import them from the hub instead. All validators will be\n", + " removed from this module in the next major release.\n", + "\n", + " Install with: `guardrails hub install hub:///`\n", + " Import as: from guardrails.hub import `ValidatorName`\n", + " \n", + " warn(\n" ] } ], @@ -177,7 +187,16 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/calebcourier/Projects/gr-mono/guardrails/docs/examples/.venv/lib/python3.11/site-packages/guardrails/prompt/base_prompt.py:59: FutureWarning: Prompt Primitives are moving! To keep the same behaviour, switch from `json` constants to `xml` constants. Example: ${gr.complete_json_suffix} -> ${gr.complete_xml_suffix}\n", + " warn(\n" + ] + } + ], "source": [ "guard = gd.Guard.from_rail_string(rail_str)" ] @@ -193,7 +212,16 @@ "cell_type": "code", "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/calebcourier/Projects/gr-mono/guardrails/docs/examples/.venv/lib/python3.11/site-packages/guardrails/prompt/base_prompt.py:59: FutureWarning: Prompt Primitives are moving! To keep the same behaviour, switch from `json` constants to `xml` constants. 
Example: ${gr.complete_json_suffix} -> ${gr.complete_xml_suffix}\n", + " warn(\n" + ] + } + ], "source": [ "guard = gd.Guard.from_pydantic(output_class=ScrubbedCode, prompt=prompt)" ] @@ -211,6 +239,14 @@ "execution_count": 7, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/w2/ssf16z690zd7_4dggw0y5s_m0000gn/T/ipykernel_82075/3983563700.py:1: DeprecationWarning: 'Guard.base_prompt' is deprecated and will be removed in versions 0.5.x and beyond. Use 'Guard.history.last.prompt' instead.\n", + " print(guard.base_prompt)\n" + ] + }, { "data": { "text/html": [ @@ -295,14 +331,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import openai\n", "\n", "raw_llm_response, validated_response, *rest = guard(\n", - " openai.completions.create, model=\"text-davinci-003\", max_tokens=2048, temperature=0\n", + " openai.completions.create, model=\"gpt-3.5-turbo-instruct\", max_tokens=2048, temperature=0\n", ")" ] }, @@ -314,19 +350,11 @@ { "data": { "text/html": [ - "
{\n",
-       "    'api_help': 'curl -X POST -H \"Content-Type: application/json\" -H \"Authorization: Bearer YOUR_API_KEY\" -d \n",
-       "\\'{\"prompt\": \"Once upon a time\", \"max_tokens\": 100}\\' \n",
-       "\"https://api.openai.com/v1/engines/davinci-codex/completions\"'\n",
-       "}\n",
+       "{'api_help': 'Show an example curl command for using openai Completion API'}\n",
        "
\n" ], "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[32m'api_help'\u001b[0m: \u001b[32m'curl -X POST -H \"Content-Type: application/json\" -H \"Authorization: Bearer YOUR_API_KEY\" -d \u001b[0m\n", - "\u001b[32m\\'\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\"prompt\": \"Once upon a time\", \"max_tokens\": 100\u001b[0m\u001b[32m}\u001b[0m\u001b[32m\\' \u001b[0m\n", - "\u001b[32m\"https://api.openai.com/v1/engines/davinci-codex/completions\"'\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" + "\u001b[1m{\u001b[0m\u001b[32m'api_help'\u001b[0m: \u001b[32m'Show an example curl command for using openai Completion API'\u001b[0m\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, @@ -376,27 +404,21 @@ " │ │ format=\"1-indexed\" /></object>` => `{'baz': {'foo': 'Some String', 'index': 1}}` │ │\n", " │ │ │ │\n", " │ │ │ │\n", + " │ │ │ │\n", + " │ │ Json Output: │ │\n", + " │ │ │ │\n", + " │ │ │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", - " │ ╭───────────────────────────────────────────── Instructions ──────────────────────────────────────────────╮ │\n", - " │ │ You are a helpful assistant, able to express yourself purely through JSON, strictly and precisely │ │\n", - " │ │ adhering to the provided XML schemas. │ │\n", - " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " │ ╭──────────────────────────────────────────── Message History ────────────────────────────────────────────╮ │\n", " │ │ No message history. │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " │ ╭──────────────────────────────────────────── Raw LLM Output ─────────────────────────────────────────────╮ │\n", " │ │ { │ │\n", - " │ │ \"api_help\": \"curl -X POST -H \\\"Content-Type: application/json\\\" -H \\\"Authorization: Bearer │ │\n", - " │ │ YOUR_API_KEY\\\" -d '{\\\"prompt\\\": \\\"Once upon a time\\\", \\\"max_tokens\\\": 100}' │ │\n", - " │ │ \\\"https://api.openai.com/v1/engines/davinci-codex/completions\\\"\" │ │\n", + " │ │ \"api_help\": \"Show an example curl command for using openai Completion API\" │ │\n", " │ │ } │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " │ ╭─────────────────────────────────────────── Validated Output ────────────────────────────────────────────╮ │\n", - " │ │ { │ │\n", - " │ │ 'api_help': 'curl -X POST -H \"Content-Type: application/json\" -H \"Authorization: Bearer │ │\n", - " │ │ YOUR_API_KEY\" -d \\'{\"prompt\": \"Once upon a time\", \"max_tokens\": 100}\\' │ │\n", - " │ │ \"https://api.openai.com/v1/engines/davinci-codex/completions\"' │ │\n", - " │ │ } │ │\n", + " │ │ {'api_help': 'Show an example curl command for using openai Completion API'} │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "
\n" @@ -433,27 +455,21 @@ " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255mformat=\"1-indexed\" />` => `{'baz': {'foo': 'Some String', 'index': 1}}`\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", + " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", + " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255mJson Output:\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", + " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", + " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", " │ \u001b[48;2;240;248;255m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", - " │ \u001b[48;2;255;240;242m╭─\u001b[0m\u001b[48;2;255;240;242m────────────────────────────────────────────\u001b[0m\u001b[48;2;255;240;242m Instructions \u001b[0m\u001b[48;2;255;240;242m─────────────────────────────────────────────\u001b[0m\u001b[48;2;255;240;242m─╮\u001b[0m │\n", - " │ \u001b[48;2;255;240;242m│\u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242mYou are a helpful assistant, able to express yourself purely through JSON, strictly and precisely \u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m│\u001b[0m │\n", - " │ \u001b[48;2;255;240;242m│\u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242madhering to the provided XML schemas.\u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m│\u001b[0m │\n", - " │ \u001b[48;2;255;240;242m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " │ \u001b[48;2;231;223;235m╭─\u001b[0m\u001b[48;2;231;223;235m───────────────────────────────────────────\u001b[0m\u001b[48;2;231;223;235m Message History \u001b[0m\u001b[48;2;231;223;235m───────────────────────────────────────────\u001b[0m\u001b[48;2;231;223;235m─╮\u001b[0m │\n", " │ \u001b[48;2;231;223;235m│\u001b[0m\u001b[48;2;231;223;235m \u001b[0m\u001b[48;2;231;223;235mNo message history.\u001b[0m\u001b[48;2;231;223;235m \u001b[0m\u001b[48;2;231;223;235m \u001b[0m\u001b[48;2;231;223;235m│\u001b[0m │\n", " │ \u001b[48;2;231;223;235m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " │ \u001b[48;2;245;245;220m╭─\u001b[0m\u001b[48;2;245;245;220m───────────────────────────────────────────\u001b[0m\u001b[48;2;245;245;220m Raw LLM Output \u001b[0m\u001b[48;2;245;245;220m────────────────────────────────────────────\u001b[0m\u001b[48;2;245;245;220m─╮\u001b[0m 
│\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m{\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"api_help\": \"curl -X POST -H \\\"Content-Type: application/json\\\" -H \\\"Authorization: Bearer \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220mYOUR_API_KEY\\\" -d '{\\\"prompt\\\": \\\"Once upon a time\\\", \\\"max_tokens\\\": 100}' \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m\\\"https://api.openai.com/v1/engines/davinci-codex/completions\\\"\"\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \"api_help\": \"Show an example curl command for using openai Completion API\"\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m}\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " │ \u001b[48;2;240;255;240m╭─\u001b[0m\u001b[48;2;240;255;240m──────────────────────────────────────────\u001b[0m\u001b[48;2;240;255;240m Validated Output \u001b[0m\u001b[48;2;240;255;240m───────────────────────────────────────────\u001b[0m\u001b[48;2;240;255;240m─╮\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m{\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m 'api_help': 'curl -X POST -H \"Content-Type: application/json\" -H \"Authorization: Bearer \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240mYOUR_API_KEY\" -d \\'{\"prompt\": \"Once upon a time\", \"max_tokens\": 100}\\' \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m\"https://api.openai.com/v1/engines/davinci-codex/completions\"'\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m}\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m{'api_help': 'Show an example curl command for using openai 
Completion API'}\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", " │ \u001b[48;2;240;255;240m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n" ] @@ -483,7 +499,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.17" + "version": "3.11.9" }, "orig_nbformat": 4, "vscode": { diff --git a/docs/examples/response_is_on_topic.ipynb b/docs/examples/response_is_on_topic.ipynb index 2eefbbf54..58692a146 100644 --- a/docs/examples/response_is_on_topic.ipynb +++ b/docs/examples/response_is_on_topic.ipynb @@ -2,10 +2,32 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Installing hub:\u001b[35m/\u001b[0m\u001b[35m/tryolabs/\u001b[0m\u001b[95mrestricttotopic...\u001b[0m\n", + "\u001b[2K\u001b[32m[ ===]\u001b[0m Fetching manifestst\n", + "\u001b[2K\u001b[32m[== ]\u001b[0m Downloading dependencies Running command git clone --filter=blob:none --quiet https://github.com/tryolabs/restricttotopic.git /private/var/folders/w2/ssf16z690zd7_4dggw0y5s_m0000gn/T/pip-req-build-advwvzw9\n", + "\u001b[2K\u001b[32m[=== ]\u001b[0m Downloading dependencies\n", + "\u001b[1A\u001b[2K\u001b[?25l\u001b[32m[ ]\u001b[0m Running post-install setup\n", + "\u001b[1A\u001b[2K✅Successfully installed tryolabs/restricttotopic!\n", + "\n", + "\n", + "\u001b[1mImport validator:\u001b[0m\n", + "from guardrails.hub import RestrictToTopic\n", + "\n", + "\u001b[1mGet more info:\u001b[0m\n", + "\u001b[4;94mhttps://hub.guardrailsai.com/validator/tryolabs/restricttotopic\u001b[0m\n", + "\n" + ] + } + ], "source": [ + "\n", "!guardrails hub install hub://tryolabs/restricttotopic" ] }, @@ -46,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -63,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -85,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -124,22 +146,14 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/zayd/workspace/guardrails/.venv/lib/python3.9/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML\n", - " warnings.warn(\"Can't initialize NVML\")\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Validation failed for field with errors: Most relevant topic is tablet.\n" + "Validation failed for field with errors: Invalid topics found: ['tablet', 'computer', 'phone']\n" ] } ], @@ -183,14 +197,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Validation failed for field with errors: Most relevant topic is tablet.\n" + "Validation failed for field with errors: Invalid topics found: ['tablet', 'computer', 'phone']\n" ] } ], @@ -229,21 +243,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": {}, "outputs": [ - { - 
"name": "stderr", - "output_type": "stream", - "text": [ - "HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Validation failed for field with errors: Most relevant topic is tablet.\n" + "Validation failed for field with errors: Invalid topics found: ['tablet']\n" ] } ], @@ -288,7 +295,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/docs/examples/text_summarization_quality.ipynb b/docs/examples/text_summarization_quality.ipynb index cca6538eb..d40e8610b 100644 --- a/docs/examples/text_summarization_quality.ipynb +++ b/docs/examples/text_summarization_quality.ipynb @@ -74,10 +74,10 @@ "metadata": {}, "outputs": [], "source": [ - "with open('data/article1.txt', 'r') as file:\n", + "with open(\"data/article1.txt\", \"r\") as file:\n", " document = file.read()\n", " file.seek(0)\n", - " content = ''.join(line.strip() for line in file.readlines())" + " content = \"\".join(line.strip() for line in file.readlines())" ] }, { @@ -95,7 +95,8 @@ "source": [ "from string import Template\n", "\n", - "rail_str = Template(\"\"\"\n", + "rail_str = Template(\n", + " \"\"\"\n", "\n", "\n", "\n", @@ -115,7 +116,8 @@ "${gr.complete_json_suffix}\n", "\n", "\n", - "\"\"\").safe_substitute(document=document)" + "\"\"\"\n", + ").safe_substitute(document=document)" ] }, { @@ -132,6 +134,7 @@ "outputs": [], "source": [ "from pydantic import BaseModel, Field\n", + "\n", "from guardrails.hub import SimilarToDocument\n", "\n", "prompt = \"\"\"\n", @@ -142,10 +145,13 @@ "${gr.complete_json_suffix}\n", "\"\"\"\n", "\n", + "\n", "class DocumentSummary(BaseModel):\n", " summary: str = Field(\n", " description=\"Summarize the given document faithfully.\",\n", - " validators=[SimilarToDocument(document=f\"'{content}'\", threshold=0.60, on_fail=\"filter\")]\n", + " validators=[\n", + " SimilarToDocument(document=f\"'{content}'\", threshold=0.60, on_fail=\"filter\")\n", + " ],\n", " )" ] }, @@ -178,9 +184,9 @@ "metadata": {}, "outputs": [], "source": [ - "import guardrails as gd\n", + "from rich import print\n", "\n", - "from rich import print" + "import guardrails as gd" ] }, { @@ -354,11 +360,11 @@ "import openai\n", "\n", "raw_llm_response, validated_response, *rest = guard(\n", - " openai.completions.create,\n", - " prompt_params={'document': document},\n", - " model='text-davinci-003',\n", + " openai.chat.completions.create,\n", + " prompt_params={\"document\": document},\n", + " model=\"gpt-3.5-turbo\",\n", " max_tokens=2048,\n", - " temperature=0\n", + " temperature=0,\n", ")\n", "\n", "print(f\"Validated Output: {validated_response}\")" @@ -609,10 +615,10 @@ "source": [ "raw_llm_response, validated_response, *rest = guard(\n", " openai.completions.create,\n", - " prompt_params={'document': open(\"data/article1.txt\", \"r\").read()},\n", - " model='text-ada-001',\n", + " prompt_params={\"document\": open(\"data/article1.txt\", \"r\").read()},\n", + " model=\"text-ada-001\",\n", " max_tokens=512,\n", - " temperature=0\n", + " temperature=0,\n", ")\n", "\n", "print(f\"Validated Output: {validated_response}\")" diff --git a/docs/examples/translation_to_specific_language.ipynb b/docs/examples/translation_to_specific_language.ipynb index 2c03601f8..c8dacdc39 100644 --- a/docs/examples/translation_to_specific_language.ipynb +++ b/docs/examples/translation_to_specific_language.ipynb @@ -23,7 +23,7 @@ }, { 
"cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": { "tags": [] }, @@ -32,15 +32,25 @@ "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: alt-profanity-check in /home/zayd/workspace/guardrails/.venv/lib/python3.9/site-packages (1.3.2)\n", - "Requirement already satisfied: joblib>=1.3.2 in /home/zayd/workspace/guardrails/.venv/lib/python3.9/site-packages (from alt-profanity-check) (1.3.2)\n", - "Requirement already satisfied: scikit-learn==1.3.2 in /home/zayd/workspace/guardrails/.venv/lib/python3.9/site-packages (from alt-profanity-check) (1.3.2)\n", - "Requirement already satisfied: scipy>=1.5.0 in /home/zayd/workspace/guardrails/.venv/lib/python3.9/site-packages (from scikit-learn==1.3.2->alt-profanity-check) (1.9.3)\n", - "Requirement already satisfied: numpy<2.0,>=1.17.3 in /home/zayd/workspace/guardrails/.venv/lib/python3.9/site-packages (from scikit-learn==1.3.2->alt-profanity-check) (1.24.4)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/zayd/workspace/guardrails/.venv/lib/python3.9/site-packages (from scikit-learn==1.3.2->alt-profanity-check) (3.2.0)\n", - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.1\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + "Collecting alt-profanity-check\n", + " Downloading alt_profanity_check-1.5.0.tar.gz (758 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m759.0/759.0 kB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25ldone\n", + "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", + "\u001b[?25h Installing backend dependencies ... \u001b[?25ldone\n", + "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: scikit-learn==1.5.0 in ./.venv/lib/python3.10/site-packages (from alt-profanity-check) (1.5.0)\n", + "Requirement already satisfied: joblib>=1.4.0 in ./.venv/lib/python3.10/site-packages (from alt-profanity-check) (1.4.2)\n", + "Requirement already satisfied: numpy>=1.19.5 in ./.venv/lib/python3.10/site-packages (from scikit-learn==1.5.0->alt-profanity-check) (1.26.4)\n", + "Requirement already satisfied: scipy>=1.6.0 in ./.venv/lib/python3.10/site-packages (from scikit-learn==1.5.0->alt-profanity-check) (1.13.1)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in ./.venv/lib/python3.10/site-packages (from scikit-learn==1.5.0->alt-profanity-check) (3.5.0)\n", + "Building wheels for collected packages: alt-profanity-check\n", + " Building wheel for alt-profanity-check (pyproject.toml) ... 
\u001b[?25ldone\n", + "\u001b[?25h Created wheel for alt-profanity-check: filename=alt_profanity_check-1.5.0-py3-none-any.whl size=758311 sha256=e0f54f82189ad2c90aeb27cb9239175c71d38606836be9e4762fb64b2e2de0a0\n", + " Stored in directory: /Users/wyatt/Library/Caches/pip/wheels/18/c3/20/637574a9badb43cace85202ca31f49f47e3fe65e076459f3ed\n", + "Successfully built alt-profanity-check\n", + "Installing collected packages: alt-profanity-check\n", + "Successfully installed alt-profanity-check-1.5.0\n" ] } ], @@ -70,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "metadata": { "tags": [] }, @@ -79,14 +89,30 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/zayd/workspace/guardrails/.venv/lib/python3.9/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML\n", - " warnings.warn(\"Can't initialize NVML\")\n" + "/Users/wyatt/Projects/guardrails/docs/examples/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/Users/wyatt/Projects/guardrails/guardrails/validators/__init__.py:51: FutureWarning: \n", + " Importing validators from `guardrails.validators` is deprecated.\n", + " All validators are now available in the Guardrails Hub. Please install\n", + " and import them from the hub instead. All validators will be\n", + " removed from this module in the next major release.\n", + "\n", + " Install with: `guardrails hub install hub:///`\n", + " Import as: from guardrails.hub import `ValidatorName`\n", + " \n", + " warn(\n" ] } ], "source": [ "from profanity_check import predict\n", - "from guardrails.validators import Validator, register_validator, ValidationResult, PassResult, FailResult\n", + "from guardrails.validators import (\n", + " Validator,\n", + " register_validator,\n", + " ValidationResult,\n", + " PassResult,\n", + " FailResult,\n", + ")\n", "\n", "\n", "from typing import Dict, Any\n", @@ -113,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "metadata": { "tags": [] }, @@ -153,9 +179,23 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/wyatt/Projects/guardrails/guardrails/validator_base.py:460: FutureWarning: Accessing `IsProfanityFree` using\n", + "`from guardrails.validators import IsProfanityFree` is deprecated and\n", + "support will be removed after version 0.5.x. 
Please switch to the Guardrails Hub syntax:\n", + "`from guardrails.hub import ProfanityFree` for future updates and support.\n", + "For additional details, please visit: https://hub.guardrailsai.com/validator/guardrails/profanity_free.\n", + "\n", + " warn(\n" + ] + } + ], "source": [ "from pydantic import BaseModel, Field\n", "\n", @@ -167,11 +207,12 @@ "${gr.complete_json_suffix}\n", "\"\"\"\n", "\n", + "\n", "class Translation(BaseModel):\n", " translated_statement: str = Field(\n", " description=\"Translate the given statement into english language\",\n", - " validators=[IsProfanityFree(on_fail=\"fix\")]\n", - " )" + " validators=[IsProfanityFree(on_fail=\"fix\")],\n", + " )" ] }, { @@ -198,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -216,11 +257,20 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/wyatt/Projects/guardrails/guardrails/prompt/base_prompt.py:59: FutureWarning: Prompt Primitives are moving! To keep the same behaviour, switch from `json` constants to `xml` constants. Example: ${gr.complete_json_suffix} -> ${gr.complete_xml_suffix}\n", + " warn(\n" + ] + } + ], "source": [ "guard = gd.Guard.from_rail_string(rail_str)" ] @@ -234,9 +284,25 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/wyatt/Projects/guardrails/guardrails/validator_base.py:460: FutureWarning: Accessing `IsProfanityFree` using\n", + "`from guardrails.validators import IsProfanityFree` is deprecated and\n", + "support will be removed after version 0.5.x. Please switch to the Guardrails Hub syntax:\n", + "`from guardrails.hub import ProfanityFree` for future updates and support.\n", + "For additional details, please visit: https://hub.guardrailsai.com/validator/guardrails/profanity_free.\n", + "\n", + " warn(\n", + "/Users/wyatt/Projects/guardrails/guardrails/prompt/base_prompt.py:59: FutureWarning: Prompt Primitives are moving! To keep the same behaviour, switch from `json` constants to `xml` constants. Example: ${gr.complete_json_suffix} -> ${gr.complete_xml_suffix}\n", + " warn(\n" + ] + } + ], "source": [ "guard = gd.Guard.from_pydantic(output_class=Translation, prompt=prompt)" ] @@ -250,11 +316,19 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "metadata": { "tags": [] }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8n/8qwytjb11kj_46_w3n2v4jzw0000gn/T/ipykernel_6330/3983563700.py:1: DeprecationWarning: 'Guard.base_prompt' is deprecated and will be removed in versions 0.5.x and beyond. Use 'Guard.history.last.prompt' instead.\n", + " print(guard.base_prompt)\n" + ] + }, { "data": { "text/html": [ @@ -350,21 +424,14 @@ "execution_count": 13, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "HTTP Request: POST https://api.openai.com/v1/completions \"HTTP/1.1 200 OK\"\n" - ] - }, { "data": { "text/html": [ - "
Validated Output: {'translated_statement': 'Chicken Quesadilla'}\n",
+       "
Validated Output: {'translated_statement': 'Chicken quesadilla'}\n",
        "
\n" ], "text/plain": [ - "Validated Output: \u001b[1m{\u001b[0m\u001b[32m'translated_statement'\u001b[0m: \u001b[32m'Chicken Quesadilla'\u001b[0m\u001b[1m}\u001b[0m\n" + "Validated Output: \u001b[1m{\u001b[0m\u001b[32m'translated_statement'\u001b[0m: \u001b[32m'Chicken quesadilla'\u001b[0m\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, @@ -375,9 +442,9 @@ "import openai\n", "\n", "raw_llm_response, validated_response, *rest = guard(\n", - " openai.completions.create,\n", + " openai.chat.completions.create,\n", " prompt_params={\"statement_to_be_translated\": \"quesadilla de pollo\"},\n", - " model=\"text-davinci-003\",\n", + " model=\"gpt-3.5-turbo\",\n", " max_tokens=2048,\n", " temperature=0,\n", ")\n", @@ -394,7 +461,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -432,19 +499,19 @@ " │ │ format=\"1-indexed\" /></object>` => `{'baz': {'foo': 'Some String', 'index': 1}}` │ │\n", " │ │ │ │\n", " │ │ │ │\n", - " │ │ │ │\n", - " │ │ Json Output: │ │\n", - " │ │ │ │\n", - " │ │ │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", + " │ ╭───────────────────────────────────────────── Instructions ──────────────────────────────────────────────╮ │\n", + " │ │ You are a helpful assistant, able to express yourself purely through JSON, strictly and precisely │ │\n", + " │ │ adhering to the provided XML schemas. │ │\n", + " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " │ ╭──────────────────────────────────────────── Message History ────────────────────────────────────────────╮ │\n", " │ │ No message history. │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " │ ╭──────────────────────────────────────────── Raw LLM Output ─────────────────────────────────────────────╮ │\n", - " │ │ {\"translated_statement\": \"Chicken Quesadilla\"} │ │\n", + " │ │ {\"translated_statement\":\"Chicken quesadilla\"} │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " │ ╭─────────────────────────────────────────── Validated Output ────────────────────────────────────────────╮ │\n", - " │ │ {'translated_statement': 'Chicken Quesadilla'} │ │\n", + " │ │ {'translated_statement': 'Chicken quesadilla'} │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "
\n" @@ -482,19 +549,19 @@ " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255mformat=\"1-indexed\" />` => `{'baz': {'foo': 'Some String', 'index': 1}}`\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", - " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", - " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255mJson Output:\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", - " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", - " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", " │ \u001b[48;2;240;248;255m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", + " │ \u001b[48;2;255;240;242m╭─\u001b[0m\u001b[48;2;255;240;242m────────────────────────────────────────────\u001b[0m\u001b[48;2;255;240;242m Instructions \u001b[0m\u001b[48;2;255;240;242m─────────────────────────────────────────────\u001b[0m\u001b[48;2;255;240;242m─╮\u001b[0m │\n", + " │ \u001b[48;2;255;240;242m│\u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242mYou are a helpful assistant, able to express yourself purely through JSON, strictly and precisely \u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m│\u001b[0m │\n", + " │ \u001b[48;2;255;240;242m│\u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242madhering to the provided XML schemas.\u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m│\u001b[0m │\n", + " │ \u001b[48;2;255;240;242m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " │ \u001b[48;2;231;223;235m╭─\u001b[0m\u001b[48;2;231;223;235m───────────────────────────────────────────\u001b[0m\u001b[48;2;231;223;235m Message History \u001b[0m\u001b[48;2;231;223;235m───────────────────────────────────────────\u001b[0m\u001b[48;2;231;223;235m─╮\u001b[0m │\n", " │ \u001b[48;2;231;223;235m│\u001b[0m\u001b[48;2;231;223;235m \u001b[0m\u001b[48;2;231;223;235mNo message history.\u001b[0m\u001b[48;2;231;223;235m \u001b[0m\u001b[48;2;231;223;235m \u001b[0m\u001b[48;2;231;223;235m│\u001b[0m │\n", " │ \u001b[48;2;231;223;235m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " │ \u001b[48;2;245;245;220m╭─\u001b[0m\u001b[48;2;245;245;220m───────────────────────────────────────────\u001b[0m\u001b[48;2;245;245;220m Raw LLM Output \u001b[0m\u001b[48;2;245;245;220m────────────────────────────────────────────\u001b[0m\u001b[48;2;245;245;220m─╮\u001b[0m 
│\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m{\"translated_statement\": \"Chicken Quesadilla\"}\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m{\"translated_statement\":\"Chicken quesadilla\"}\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " │ \u001b[48;2;240;255;240m╭─\u001b[0m\u001b[48;2;240;255;240m──────────────────────────────────────────\u001b[0m\u001b[48;2;240;255;240m Validated Output \u001b[0m\u001b[48;2;240;255;240m───────────────────────────────────────────\u001b[0m\u001b[48;2;240;255;240m─╮\u001b[0m │\n", - " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m{'translated_statement': 'Chicken Quesadilla'}\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", + " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m{'translated_statement': 'Chicken quesadilla'}\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", " │ \u001b[48;2;240;255;240m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n" ] @@ -519,18 +586,11 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 15, "metadata": { "tags": [] }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "HTTP Request: POST https://api.openai.com/v1/completions \"HTTP/1.1 200 OK\"\n" - ] - }, { "data": { "text/html": [ @@ -547,9 +607,9 @@ ], "source": [ "raw_llm_response, validated_response, *rest = guard(\n", - " openai.completions.create,\n", + " openai.chat.completions.create,\n", " prompt_params={\"statement_to_be_translated\": \"убей себя\"},\n", - " model=\"text-davinci-003\",\n", + " model=\"gpt-3.5-turbo\",\n", " max_tokens=2048,\n", " temperature=0,\n", ")\n", @@ -567,7 +627,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -605,16 +665,16 @@ " │ │ format=\"1-indexed\" /></object>` => `{'baz': {'foo': 'Some String', 'index': 1}}` │ │\n", " │ │ │ │\n", " │ │ │ │\n", - " │ │ │ │\n", - " │ │ Json Output: │ │\n", - " │ │ │ │\n", - " │ │ │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", + " │ ╭───────────────────────────────────────────── Instructions ──────────────────────────────────────────────╮ │\n", + " │ │ You are a helpful assistant, able to express yourself purely through JSON, strictly and precisely │ │\n", + " │ │ adhering to the provided XML schemas. │ │\n", + " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " │ ╭──────────────────────────────────────────── Message History ────────────────────────────────────────────╮ │\n", " │ │ No message history. 
│ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " │ ╭──────────────────────────────────────────── Raw LLM Output ─────────────────────────────────────────────╮ │\n", - " │ │ {\"translated_statement\": \"Kill yourself\"} │ │\n", + " │ │ {\"translated_statement\":\"Kill yourself\"} │ │\n", " │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯ │\n", " │ ╭─────────────────────────────────────────── Validated Output ────────────────────────────────────────────╮ │\n", " │ │ {'translated_statement': ''} │ │\n", @@ -655,16 +715,16 @@ " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255mformat=\"1-indexed\" />` => `{'baz': {'foo': 'Some String', 'index': 1}}`\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", - " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", - " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255mJson Output:\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", - " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", - " │ \u001b[48;2;240;248;255m│\u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m \u001b[0m\u001b[48;2;240;248;255m│\u001b[0m │\n", " │ \u001b[48;2;240;248;255m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", + " │ \u001b[48;2;255;240;242m╭─\u001b[0m\u001b[48;2;255;240;242m────────────────────────────────────────────\u001b[0m\u001b[48;2;255;240;242m Instructions \u001b[0m\u001b[48;2;255;240;242m─────────────────────────────────────────────\u001b[0m\u001b[48;2;255;240;242m─╮\u001b[0m │\n", + " │ \u001b[48;2;255;240;242m│\u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242mYou are a helpful assistant, able to express yourself purely through JSON, strictly and precisely \u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m│\u001b[0m │\n", + " │ \u001b[48;2;255;240;242m│\u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242madhering to the provided XML schemas.\u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m \u001b[0m\u001b[48;2;255;240;242m│\u001b[0m │\n", + " │ \u001b[48;2;255;240;242m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " │ \u001b[48;2;231;223;235m╭─\u001b[0m\u001b[48;2;231;223;235m───────────────────────────────────────────\u001b[0m\u001b[48;2;231;223;235m Message History \u001b[0m\u001b[48;2;231;223;235m───────────────────────────────────────────\u001b[0m\u001b[48;2;231;223;235m─╮\u001b[0m │\n", " │ 
\u001b[48;2;231;223;235m│\u001b[0m\u001b[48;2;231;223;235m \u001b[0m\u001b[48;2;231;223;235mNo message history.\u001b[0m\u001b[48;2;231;223;235m \u001b[0m\u001b[48;2;231;223;235m \u001b[0m\u001b[48;2;231;223;235m│\u001b[0m │\n", " │ \u001b[48;2;231;223;235m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " │ \u001b[48;2;245;245;220m╭─\u001b[0m\u001b[48;2;245;245;220m───────────────────────────────────────────\u001b[0m\u001b[48;2;245;245;220m Raw LLM Output \u001b[0m\u001b[48;2;245;245;220m────────────────────────────────────────────\u001b[0m\u001b[48;2;245;245;220m─╮\u001b[0m │\n", - " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m{\"translated_statement\": \"Kill yourself\"}\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", + " │ \u001b[48;2;245;245;220m│\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m{\"translated_statement\":\"Kill yourself\"}\u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m \u001b[0m\u001b[48;2;245;245;220m│\u001b[0m │\n", " │ \u001b[48;2;245;245;220m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m │\n", " │ \u001b[48;2;240;255;240m╭─\u001b[0m\u001b[48;2;240;255;240m──────────────────────────────────────────\u001b[0m\u001b[48;2;240;255;240m Validated Output \u001b[0m\u001b[48;2;240;255;240m───────────────────────────────────────────\u001b[0m\u001b[48;2;240;255;240m─╮\u001b[0m │\n", " │ \u001b[48;2;240;255;240m│\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m{'translated_statement': ''}\u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m \u001b[0m\u001b[48;2;240;255;240m│\u001b[0m │\n", @@ -697,7 +757,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.17" + "version": "3.10.3" } }, "nbformat": 4, diff --git a/docs/examples/translation_with_quality_check.ipynb b/docs/examples/translation_with_quality_check.ipynb index 3429514fb..eb627fdba 100644 --- a/docs/examples/translation_with_quality_check.ipynb +++ b/docs/examples/translation_with_quality_check.ipynb @@ -360,10 +360,10 @@ "statement = \"Ich habe keine Ahnung, was ich hier schreiben soll.\"\n", "\n", "raw_llm_response, validated_response, *rest = guard(\n", - " openai.completions.create,\n", + " openai.chat.completions.create,\n", " prompt_params={\"statement_to_be_translated\": statement},\n", " metadata={\"translation_source\": statement},\n", - " model=\"text-davinci-003\",\n", + " model=\"gpt-3.5-turbo\",\n", " max_tokens=1024,\n", " temperature=0,\n", ")\n", diff --git a/docs/examples/valid_chess_moves.ipynb b/docs/examples/valid_chess_moves.ipynb index 46b5a5c44..57980d31d 100644 --- a/docs/examples/valid_chess_moves.ipynb +++ b/docs/examples/valid_chess_moves.ipynb @@ -339,13 +339,13 @@ "import openai\n", "\n", "raw_llm_response, validated_response, *rest = guard(\n", - " openai.completions.create,\n", + " openai.chat.completions.create,\n", " prompt_params={\n", " \"board_state\": str(board.move_stack)\n", " if board.move_stack\n", " else \"Starting position.\"\n", " },\n", - " model=\"text-davinci-003\",\n", + " model=\"gpt-3.5-turbo\",\n", " max_tokens=2048,\n", " temperature=0.3,\n", ")" @@ -476,13 +476,13 @@ ], "source": [ "raw_llm_response, validated_response, *rest = guard(\n", - " openai.completions.create,\n", + " 
openai.chat.completions.create,\n", " prompt_params={\n", " \"board_state\": str(board.move_stack)\n", " if board.move_stack\n", " else \"Starting position.\"\n", " },\n", - " model=\"text-davinci-003\",\n", + " model=\"gpt-3.5-turbo\",\n", " max_tokens=2048,\n", " temperature=0.3,\n", ")" diff --git a/guardrails/api_client.py b/guardrails/api_client.py index bf8b7323b..55af0fdaf 100644 --- a/guardrails/api_client.py +++ b/guardrails/api_client.py @@ -1,9 +1,11 @@ +import json import os -from typing import Optional +from typing import Generator, Optional +import requests from guardrails_api_client import AuthenticatedClient from guardrails_api_client.api.guard import update_guard, validate -from guardrails_api_client.models import Guard, ValidatePayload +from guardrails_api_client.models import Guard, ValidatePayload, ValidationOutput from guardrails_api_client.types import UNSET from httpx import Timeout @@ -49,3 +51,34 @@ def validate( body=payload, x_openai_api_key=_openai_api_key, ) + + def stream_validate( + self, + guard: Guard, + payload: ValidatePayload, + openai_api_key: Optional[str] = None, + ) -> Generator[ValidationOutput, None, None]: + _openai_api_key = ( + openai_api_key + if openai_api_key is not None + else os.environ.get("OPENAI_API_KEY", UNSET) + ) + + url = f"{self.base_url}/guards/{guard.name}/validate" + headers = { + "Content-Type": "application/json", + "x-openai-api-key": _openai_api_key, + } + + s = requests.Session() + + with s.post(url, json=payload.to_dict(), headers=headers, stream=True) as resp: + for line in resp.iter_lines(): + if not resp.ok: + raise ValueError( + f"status_code: {resp.status_code}" + " reason: {resp.reason} text: {resp.text}" + ) + if line: + json_output = json.loads(line) + yield ValidationOutput.from_dict(json_output) diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index 77d5e5ccf..561edd2d3 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -2,6 +2,7 @@ import inspect from typing import ( Any, + AsyncIterable, Awaitable, Callable, Dict, @@ -9,16 +10,24 @@ List, Optional, Union, + cast, ) +from guardrails_api_client.models import ValidatePayload, ValidationOutput + from guardrails import Guard from guardrails.classes import OT, ValidationOutcome from guardrails.classes.history import Call from guardrails.classes.history.call_inputs import CallInputs from guardrails.llm_providers import get_async_llm_ask, model_is_supported_server_side from guardrails.logger import set_scope -from guardrails.run import AsyncRunner -from guardrails.stores.context import set_call_kwargs, set_tracer, set_tracer_context +from guardrails.run import AsyncRunner, AsyncStreamRunner +from guardrails.stores.context import ( + get_call_kwarg, + set_call_kwargs, + set_tracer, + set_tracer_context, +) class AsyncGuard(Guard): @@ -139,6 +148,7 @@ async def __call( full_schema_reask=full_schema_reask, args=list(args), kwargs=kwargs, + stream=kwargs.get("stream"), ) call_log = Call(inputs=call_inputs) set_scope(str(id(call_log))) @@ -147,7 +157,7 @@ async def __call( if self._api_client is not None and model_is_supported_server_side( llm_api, *args, **kwargs ): - return self._call_server( + result = self._call_server( llm_api=llm_api, num_reasks=self.num_reasks, prompt_params=prompt_params, @@ -156,30 +166,24 @@ async def __call( *args, **kwargs, ) - - # If the LLM API is not async, fail - # FIXME: it seems like this check isn't actually working? 
- if not inspect.isawaitable(llm_api) and not inspect.iscoroutinefunction( - llm_api - ): - raise RuntimeError( - f"The LLM API `{llm_api.__name__}` is not a coroutine. " - "Please use an async LLM API." + else: + result = self._call_async( + llm_api, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + prompt=prompt, + instructions=instructions, + msg_history=msg_history, + metadata=metadata, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, ) - # Otherwise, call the LLM - return await self._call_async( - llm_api, - prompt_params=prompt_params, - num_reasks=self.num_reasks, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, - metadata=metadata, - full_schema_reask=full_schema_reask, - call_log=call_log, - *args, - **kwargs, - ) + + if inspect.isawaitable(result): + return await result + return result guard_context = contextvars.Context() return await guard_context.run( @@ -210,7 +214,7 @@ async def _call_async( call_log: Call, *args, **kwargs, - ) -> ValidationOutcome[OT]: + ) -> Union[ValidationOutcome[OT], AsyncIterable[ValidationOutcome[OT]]]: """Call the LLM asynchronously and validate the output. Args: @@ -238,24 +242,48 @@ async def _call_async( "You must provide a prompt if msg_history is empty. " "Alternatively, you can provide a prompt in the RAIL spec." ) - - runner = AsyncRunner( - instructions=instructions_obj, - prompt=prompt_obj, - msg_history=msg_history_obj, - api=get_async_llm_ask(llm_api, *args, **kwargs), - prompt_schema=self.rail.prompt_schema, - instructions_schema=self.rail.instructions_schema, - msg_history_schema=self.rail.msg_history_schema, - output_schema=self.rail.output_schema, - num_reasks=num_reasks, - metadata=metadata, - base_model=self.base_model, - full_schema_reask=full_schema_reask, - disable_tracer=self._disable_tracer, - ) - call = await runner.async_run(call_log=call_log, prompt_params=prompt_params) - return ValidationOutcome[OT].from_guard_history(call) + if kwargs.get("stream", False): + runner = AsyncStreamRunner( + instructions=instructions_obj, + prompt=prompt_obj, + msg_history=msg_history_obj, + api=get_async_llm_ask(llm_api, *args, **kwargs), + prompt_schema=self.rail.prompt_schema, + instructions_schema=self.rail.instructions_schema, + msg_history_schema=self.rail.msg_history_schema, + output_schema=self.rail.output_schema, + num_reasks=num_reasks, + metadata=metadata, + base_model=self.base_model, + full_schema_reask=full_schema_reask, + disable_tracer=self._disable_tracer, + ) + # Here we have an async generator + async_generator = runner.async_run( + call_log=call_log, prompt_params=prompt_params + ) + return async_generator + + else: + runner = AsyncRunner( + instructions=instructions_obj, + prompt=prompt_obj, + msg_history=msg_history_obj, + api=get_async_llm_ask(llm_api, *args, **kwargs), + prompt_schema=self.rail.prompt_schema, + instructions_schema=self.rail.instructions_schema, + msg_history_schema=self.rail.msg_history_schema, + output_schema=self.rail.output_schema, + num_reasks=num_reasks, + metadata=metadata, + base_model=self.base_model, + full_schema_reask=full_schema_reask, + disable_tracer=self._disable_tracer, + ) + call = await runner.async_run( + call_log=call_log, prompt_params=prompt_params + ) + return ValidationOutcome[OT].from_guard_history(call) async def parse( self, @@ -347,6 +375,7 @@ async def __parse( full_schema_reask=full_schema_reask, args=list(args), kwargs=kwargs, + stream=kwargs.get("stream"), ) call_log = Call(inputs=call_inputs) 
set_scope(str(id(call_log))) @@ -366,31 +395,17 @@ async def __parse( **kwargs, ) - # FIXME: checking not llm_api because it can still fall back on defaults and - # function as expected. We should handle this better. - if ( - not llm_api - or inspect.iscoroutinefunction(llm_api) - or inspect.isasyncgenfunction(llm_api) - ): - return await self._async_parse( - llm_output, - metadata, - llm_api=llm_api, - num_reasks=self.num_reasks, - prompt_params=prompt_params, - full_schema_reask=full_schema_reask, - call_log=call_log, - *args, - **kwargs, - ) - - else: - raise NotImplementedError( - "AsyncGuard does not support non-async LLM APIs. " - "Please use the synchronous API Guard or supply an asynchronous " - "LLM API." - ) + return await self._async_parse( + llm_output, + metadata, + llm_api=llm_api, + num_reasks=self.num_reasks, + prompt_params=prompt_params, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, + ) guard_context = contextvars.Context() return await guard_context.run( @@ -447,3 +462,50 @@ async def _async_parse( call = await runner.async_run(call_log=call_log, prompt_params=prompt_params) return ValidationOutcome[OT].from_guard_history(call) + + async def _stream_server_call( + self, + *, + payload: Dict[str, Any], + llm_output: Optional[str] = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + metadata: Optional[Dict] = {}, + full_schema_reask: Optional[bool] = True, + call_log: Optional[Call], + ) -> AsyncIterable[ValidationOutcome[OT]]: + # TODO: Once server side supports async streaming, this function will need to + # yield async generators, not generators + if self._api_client: + validation_output: Optional[ValidationOutput] = None + response = self._api_client.stream_validate( + guard=self, # type: ignore + payload=ValidatePayload.from_dict(payload), + openai_api_key=get_call_kwarg("api_key"), + ) + for fragment in response: + validation_output = fragment + if not validation_output: + yield ValidationOutcome[OT]( + raw_llm_output=None, + validated_output=None, + validation_passed=False, + error="The response from the server was empty!", + ) + yield ValidationOutcome[OT]( + raw_llm_output=validation_output.raw_llm_response, # type: ignore + validated_output=cast(OT, validation_output.validated_output), + validation_passed=validation_output.result, + ) + if validation_output: + self._construct_history_from_server_response( + validation_output=validation_output, + llm_output=llm_output, + num_reasks=num_reasks, + prompt_params=prompt_params, + metadata=metadata, + full_schema_reask=full_schema_reask, + call_log=call_log, + ) + else: + raise ValueError("Guard does not have an api client!") diff --git a/guardrails/classes/history/call.py b/guardrails/classes/history/call.py index 503cefb9a..6598e3c12 100644 --- a/guardrails/classes/history/call.py +++ b/guardrails/classes/history/call.py @@ -20,7 +20,7 @@ sub_reasks_with_fixed_values, ) from guardrails.utils.safe_get import get_value_from_path -from guardrails.validator_base import Filter, Refrain +from guardrails.validator_base import Filter, Refrain, ValidationResult # We can't inherit from Iteration because python @@ -353,6 +353,7 @@ def failed_validations(self) -> Stack[ValidatorLogs]: log for log in self.validator_logs if log.validation_result is not None + and isinstance(log.validation_result, ValidationResult) and log.validation_result.outcome == "fail" ] ) diff --git a/guardrails/classes/history/inputs.py b/guardrails/classes/history/inputs.py index 
fd9e6762a..07a34920e 100644 --- a/guardrails/classes/history/inputs.py +++ b/guardrails/classes/history/inputs.py @@ -46,3 +46,7 @@ class Inputs(ArbitraryModel): "or at the field level.", default=None, ) + stream: Optional[bool] = Field( + description="Whether to use streaming.", + default=False, + ) diff --git a/guardrails/classes/history/iteration.py b/guardrails/classes/history/iteration.py index 1eb9e12dd..a8fb67ada 100644 --- a/guardrails/classes/history/iteration.py +++ b/guardrails/classes/history/iteration.py @@ -15,6 +15,7 @@ from guardrails.utils.logs_utils import ValidatorLogs from guardrails.utils.pydantic_utils import ArbitraryModel from guardrails.utils.reask_utils import ReAsk +from guardrails.validator_base import ErrorSpan class Iteration(ArbitraryModel): @@ -136,6 +137,13 @@ def reasks(self) -> Sequence[ReAsk]: def validator_logs(self) -> List[ValidatorLogs]: """The results of each individual validation performed on the LLM response during this iteration.""" + if self.inputs.stream: + filtered_logs = [ + log + for log in self.outputs.validator_logs + if log.validation_result and log.validation_result.validated_chunk + ] + return filtered_logs return self.outputs.validator_logs @property @@ -155,6 +163,14 @@ def failed_validations(self) -> List[ValidatorLogs]: iteration.""" return self.outputs.failed_validations + @property + def error_spans_in_output(self) -> List[ErrorSpan]: + """The error spans from the LLM response. + + These indices are relative to the complete LLM output. + """ + return self.outputs.error_spans_in_output + @property def status(self) -> str: """Representation of the end state of this iteration. diff --git a/guardrails/classes/history/outputs.py b/guardrails/classes/history/outputs.py index 9d4d19544..ead39de20 100644 --- a/guardrails/classes/history/outputs.py +++ b/guardrails/classes/history/outputs.py @@ -8,7 +8,7 @@ from guardrails.utils.logs_utils import ValidatorLogs from guardrails.utils.pydantic_utils import ArbitraryModel from guardrails.utils.reask_utils import ReAsk -from guardrails.validator_base import FailResult +from guardrails.validator_base import ErrorSpan, FailResult, ValidationResult class Outputs(ArbitraryModel): @@ -71,10 +71,36 @@ def failed_validations(self) -> List[ValidatorLogs]: log for log in self.validator_logs if log.validation_result is not None + and isinstance(log.validation_result, ValidationResult) and log.validation_result.outcome == "fail" ] ) + @property + def error_spans_in_output(self) -> List[ErrorSpan]: + """The error spans from the LLM response. + + These indices are relative to the complete LLM output. + """ + total_len = 0 + spans_in_output = [] + for log in self.validator_logs: + result = log.validation_result + if isinstance(result, FailResult): + if result.error_spans is not None: + for error_span in result.error_spans: + spans_in_output.append( + ErrorSpan( + start=error_span.start + total_len, + end=error_span.end + total_len, + reason=error_span.reason, + ) + ) + if isinstance(result, ValidationResult): + if result and result.validated_chunk is not None: + total_len += len(result.validated_chunk) + return spans_in_output + @property def status(self) -> str: """Representation of the end state of the validation run. 
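The offset bookkeeping introduced in `Outputs.error_spans_in_output` above is easy to misread: each validator reports error spans relative to the chunk it validated, and the property shifts them by the running length of previously validated chunks so the resulting indices address the complete streamed output. Below is a minimal, self-contained sketch of that remapping; the chunks, spans, and reasons are made-up illustration values, and only the arithmetic mirrors the logic added in the hunk above.

# Hedged illustration of the span-offset logic; data here is hypothetical.
chunks = ["Hello wrold, ", "this is fien."]           # validated chunks, in order
chunk_spans = [
    [(6, 11, "misspelling")],                         # spans relative to chunk 0
    [(8, 12, "misspelling")],                         # spans relative to chunk 1
]

total_len = 0
spans_in_output = []
for chunk, spans in zip(chunks, chunk_spans):
    for start, end, reason in spans:
        # Shift chunk-relative indices into full-output coordinates.
        spans_in_output.append((start + total_len, end + total_len, reason))
    total_len += len(chunk)

full_output = "".join(chunks)
for start, end, reason in spans_in_output:
    # Each remapped span still points at the offending substring.
    assert full_output[start:end] in ("wrold", "fien")
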
diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py index d4d94ad6f..931b2383f 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -156,6 +156,7 @@ def install_hub_module( flags=[f"--path={install_directory}"], format=json_format, quiet=quiet, + no_color=True, ) # throw if inspect_output is a string. Mostly for pyright @@ -169,7 +170,7 @@ def install_hub_module( .get("metadata", {}) # type: ignore .get("requires_dist", []) # type: ignore ) - requirements = filter(lambda dep: "extra" not in dep, dependencies) + requirements = list(filter(lambda dep: "extra" not in dep, dependencies)) for req in requirements: if "git+" in req: install_spec = req.replace(" ", "") diff --git a/guardrails/cli/hub/utils.py b/guardrails/cli/hub/utils.py index 384e76ee2..6c0b2d968 100644 --- a/guardrails/cli/hub/utils.py +++ b/guardrails/cli/hub/utils.py @@ -1,5 +1,6 @@ import json import os +import re import subprocess import sys from email.parser import BytesHeaderParser @@ -20,33 +21,44 @@ def pip_process( flags: List[str] = [], format: Union[Literal["string"], Literal["json"]] = string_format, quiet: bool = False, + no_color: bool = False, ) -> Union[str, dict]: try: if not quiet: logger.debug(f"running pip {action} {' '.join(flags)} {package}") - command = [sys.executable, "-m", "pip", action] + flags + command = [sys.executable, "-m", "pip", action] + command.extend(flags) if package: command.append(package) + env = dict(os.environ) + if no_color: + env["NO_COLOR"] = "true" if not quiet: logger.debug(f"decoding output from pip {action} {package}") - output = subprocess.check_output(command) + output = subprocess.check_output(command, env=env) else: - output = subprocess.check_output(command, stderr=subprocess.DEVNULL) + output = subprocess.check_output( + command, stderr=subprocess.DEVNULL, env=env + ) if format == json_format: parsed = BytesHeaderParser().parsebytes(output) try: - return json.loads(str(parsed)) + remove_color_codes = re.compile(r"\x1b\[[0-9;]*m") + parsed_as_string = re.sub( + remove_color_codes, "", parsed.as_string().strip() + ) + return json.loads(parsed_as_string) except Exception: logger.debug( f"JSON parse exception in decoding output from pip \ {action} {package}. 
Falling back to accumulating the byte stream", ) - accumulator = {} - for key, value in parsed.items(): - accumulator[key] = value - return accumulator + accumulator = {} + for key, value in parsed.items(): + accumulator[key] = value + return accumulator return str(output.decode()) except subprocess.CalledProcessError as exc: logger.error( diff --git a/guardrails/cli/server/module_manifest.py b/guardrails/cli/server/module_manifest.py index 454d718a8..61a5c7ccf 100644 --- a/guardrails/cli/server/module_manifest.py +++ b/guardrails/cli/server/module_manifest.py @@ -28,6 +28,13 @@ class ModuleTags(Serializeable): process_requirements: Optional[List[str]] = field(default_factory=list) +@dataclass +class ModelAuth(Serializeable): + type: str + name: str + displayName: Optional[str] = None + + @dataclass class ModuleManifest(Serializeable): id: str @@ -43,6 +50,7 @@ class ModuleManifest(Serializeable): requires_auth: Optional[bool] = True post_install: Optional[str] = None index: Optional[str] = None + required_model_auth: Optional[List[ModelAuth]] = field(default_factory=list) # @override @classmethod diff --git a/guardrails/document_store.py b/guardrails/document_store.py index 4a9a31298..8e6ad7c25 100644 --- a/guardrails/document_store.py +++ b/guardrails/document_store.py @@ -238,21 +238,23 @@ def get_pages_for_for_indexes(self, indexes: List[int]) -> List[Page]: except ImportError: class FallbackEphemeralDocumentStore: - def __init__(self): + def __init__(self, *args, **kwargs): + # Why don't we just raise this when the import + # error occurs instead of at runtime? raise ImportError( "SQLAlchemy is required for EphemeralDocumentStore" "Please install it using `poetry add SqlAlchemy`" ) class FallbackSQLDocument: - def __init__(self): + def __init__(self, *args, **kwargs): raise ImportError( "SQLAlchemy is required for SQLDocument" "Please install it using `poetry add SqlAlchemy`" ) class FallbackSQLMetadataStore: - def __init__(self): + def __init__(self, *args, **kwargs): raise ImportError( "SQLAlchemy is required for SQLMetadataStore" "Please install it using `poetry add SqlAlchemy`" diff --git a/guardrails/guard.py b/guardrails/guard.py index 63e9e6da9..f8fc4627b 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -10,6 +10,7 @@ Awaitable, Callable, Dict, + Generator, Generic, Iterable, List, @@ -66,6 +67,7 @@ set_tracer, set_tracer_context, ) +from guardrails.utils.api_utils import extract_serializeable_metadata from guardrails.utils.hub_telemetry_utils import HubTelemetry from guardrails.utils.llm_response import LLMResponse from guardrails.utils.reask_utils import FieldReAsk @@ -146,18 +148,7 @@ def __init__( self.description = description self.name = name - api_key = os.environ.get("GUARDRAILS_API_KEY") - if api_key is not None: - if self.name is None: - self.name = f"gr-{str(self._guard_id)}" - logger.warn("Warning: No name passed to guard!") - logger.warn( - "Use this auto-generated name to re-use this guard: {name}".format( - name=self.name - ) - ) - self._api_client = GuardrailsApiClient(api_key=api_key) - self.upsert_guard() + self._save() @property @deprecated( @@ -643,6 +634,7 @@ def __call( full_schema_reask=full_schema_reask, args=list(args), kwargs=kwargs, + stream=kwargs.get("stream"), ) call_log = Call(inputs=call_inputs) set_scope(str(id(call_log))) @@ -653,8 +645,12 @@ def __call( ): return self._call_server( llm_api=llm_api, - num_reasks=self.num_reasks, prompt_params=prompt_params, + num_reasks=self.num_reasks, + prompt=prompt, + instructions=instructions, 
+ msg_history=msg_history, + metadata=metadata, full_schema_reask=full_schema_reask, call_log=call_log, *args, @@ -999,6 +995,7 @@ def __parse( full_schema_reask=full_schema_reask, args=list(args), kwargs=kwargs, + stream=kwargs.get("stream"), ) call_log = Call(inputs=call_inputs) set_scope(str(id(call_log))) @@ -1009,6 +1006,7 @@ def __parse( ): return self._call_server( llm_output=llm_output, + metadata=metadata, llm_api=llm_api, num_reasks=self.num_reasks, prompt_params=prompt_params, @@ -1148,6 +1146,18 @@ async def _async_parse( return ValidationOutcome[OT].from_guard_history(call) + def error_spans_in_output(self): + try: + call = self.history.last + if call: + iter = call.iterations.last + if iter: + llm_spans = iter.error_spans_in_output + return llm_spans + return [] + except (AttributeError, TypeError): + return [] + @deprecated( """The `with_prompt_validation` method is deprecated, and will be removed in 0.5.x. Instead, please use @@ -1295,6 +1305,7 @@ def use( """ hydrated_validator = get_validator(validator, *args, **kwargs) self.__add_validator(hydrated_validator, on=on) + self._save() return self @overload @@ -1336,6 +1347,7 @@ def use_many( for v in validators: hydrated_validator = get_validator(v) self.__add_validator(hydrated_validator, on=on) + self._save() return self def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]: @@ -1358,6 +1370,7 @@ def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]: if self.rail.output_schema.reask_instructions_template else None, ) + self._save() return self.parse(llm_output=llm_output, *args, **kwargs) @@ -1417,39 +1430,114 @@ def upsert_guard(self): else: raise ValueError("Guard does not have an api client!") - def _call_server( + def _construct_history_from_server_response( self, - *args, - llm_output: Optional[str] = None, + *, + validation_output: Optional[ValidationOutput] = None, llm_api: Optional[Callable] = None, + llm_output: Optional[str] = None, num_reasks: Optional[int] = None, prompt_params: Optional[Dict] = None, metadata: Optional[Dict] = {}, full_schema_reask: Optional[bool] = True, call_log: Optional[Call], - # prompt: Optional[str], - # instructions: Optional[str], - # msg_history: Optional[List[Dict]], - **kwargs, + stream: Optional[bool] = False, ): + call_log = call_log or Call() + if llm_api is not None: + llm_api = get_llm_ask(llm_api) + if asyncio.iscoroutinefunction(llm_api): + llm_api = get_async_llm_ask(llm_api) + session_history = ( + validation_output.session_history + if validation_output is not None and validation_output.session_history + else [] + ) + history: History + for history in session_history: + history_events: Optional[List[HistoryEvent]] = ( # type: ignore + history.history if history.history != UNSET else None + ) + if history_events is None: + continue + + iterations = [ + Iteration( + inputs=Inputs( + llm_api=llm_api, + llm_output=llm_output, + instructions=( + Instructions(h.instructions) if h.instructions else None + ), + prompt=( + Prompt(h.prompt.source) # type: ignore + if h.prompt is not None and h.prompt != UNSET + else None + ), + prompt_params=prompt_params, + num_reasks=(num_reasks or 0), + metadata=metadata, + full_schema_reask=full_schema_reask, + stream=stream, + ), + outputs=Outputs( + llm_response_info=LLMResponse(output=h.output), # type: ignore + raw_output=h.output, + parsed_output=( + h.parsed_output.to_dict() + if isinstance(h.parsed_output, AnyObject) + else h.parsed_output + ), + validation_output=( + 
h.validated_output.to_dict() + if isinstance(h.validated_output, AnyObject) + else h.validated_output + ), + reasks=list( + [ + FieldReAsk( + incorrect_value=r.to_dict().get("incorrect_value"), + path=r.to_dict().get("path"), + fail_results=[ + FailResult( + error_message=r.to_dict().get( + "error_message" + ), + fix_value=r.to_dict().get("fix_value"), + ) + ], + ) + for r in h.reasks # type: ignore + ] + if h.reasks != UNSET + else [] + ), + ), + ) + for h in history_events + ] + call_log.iterations.extend(iterations) + if self.history.length == 0: + self.history.push(call_log) + + def _single_server_call( + self, + *, + payload: Dict[str, Any], + llm_output: Optional[str] = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + metadata: Optional[Dict] = {}, + full_schema_reask: Optional[bool] = True, + call_log: Optional[Call], + stream: Optional[bool] = False, + ) -> ValidationOutcome[OT]: if self._api_client: - payload: Dict[str, Any] = {"args": list(args)} - payload.update(**kwargs) - if llm_output is not None: - payload["llmOutput"] = llm_output - if num_reasks is not None: - payload["numReasks"] = num_reasks - if prompt_params is not None: - payload["promptParams"] = prompt_params - if llm_api is not None: - payload["llmApi"] = get_llm_api_enum(llm_api) - # TODO: get enum for llm_api - validation_output: Optional[ValidationOutput] = self._api_client.validate( + validation_output: ValidationOutput = self._api_client.validate( guard=self, # type: ignore payload=ValidatePayload.from_dict(payload), openai_api_key=get_call_kwarg("api_key"), ) - if not validation_output: return ValidationOutcome[OT]( raw_llm_output=None, @@ -1457,86 +1545,16 @@ def _call_server( validation_passed=False, error="The response from the server was empty!", ) - - call_log = call_log or Call() - if llm_api is not None: - llm_api = get_llm_ask(llm_api) - if asyncio.iscoroutinefunction(llm_api): - llm_api = get_async_llm_ask(llm_api) - session_history = ( - validation_output.session_history - if validation_output is not None and validation_output.session_history - else [] + self._construct_history_from_server_response( + validation_output=validation_output, + llm_output=llm_output, + num_reasks=num_reasks, + prompt_params=prompt_params, + metadata=metadata, + full_schema_reask=full_schema_reask, + call_log=call_log, + stream=stream, ) - history: History - for history in session_history: - history_events: Optional[List[HistoryEvent]] = ( # type: ignore - history.history if history.history != UNSET else None - ) - if history_events is None: - continue - - iterations = [ - Iteration( - inputs=Inputs( - llm_api=llm_api, - llm_output=llm_output, - instructions=( - Instructions(h.instructions) if h.instructions else None - ), - prompt=( - Prompt(h.prompt.source) # type: ignore - if h.prompt is not None and h.prompt != UNSET - else None - ), - prompt_params=prompt_params, - num_reasks=(num_reasks or 0), - metadata=metadata, - full_schema_reask=full_schema_reask, - ), - outputs=Outputs( - llm_response_info=LLMResponse( - output=h.output # type: ignore - ), - raw_output=h.output, - parsed_output=( - h.parsed_output.to_dict() - if isinstance(h.parsed_output, AnyObject) - else h.parsed_output - ), - validation_output=( - h.validated_output.to_dict() - if isinstance(h.validated_output, AnyObject) - else h.validated_output - ), - reasks=list( - [ - FieldReAsk( - incorrect_value=r.to_dict().get( - "incorrect_value" - ), - path=r.to_dict().get("path"), - fail_results=[ - FailResult( - 
error_message=r.to_dict().get( - "error_message" - ), - fix_value=r.to_dict().get("fix_value"), - ) - ], - ) - for r in h.reasks # type: ignore - ] - if h.reasks != UNSET - else [] - ), - ), - ) - for h in history_events - ] - call_log.iterations.extend(iterations) - if self.history.length == 0: - self.history.push(call_log) # Our interfaces are too different for this to work right now. # Once we move towards shared interfaces for both the open source @@ -1550,6 +1568,120 @@ def _call_server( else: raise ValueError("Guard does not have an api client!") + def _stream_server_call( + self, + *, + payload: Dict[str, Any], + llm_output: Optional[str] = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + metadata: Optional[Dict] = {}, + full_schema_reask: Optional[bool] = True, + call_log: Optional[Call], + stream: Optional[bool] = False, + ) -> Generator[ValidationOutcome[OT], None, None]: + if self._api_client: + validation_output: Optional[ValidationOutput] = None + response = self._api_client.stream_validate( + guard=self, # type: ignore + payload=ValidatePayload.from_dict(payload), + openai_api_key=get_call_kwarg("api_key"), + ) + for fragment in response: + validation_output = fragment + if not validation_output: + yield ValidationOutcome[OT]( + raw_llm_output=None, + validated_output=None, + validation_passed=False, + error="The response from the server was empty!", + ) + yield ValidationOutcome[OT]( + raw_llm_output=validation_output.raw_llm_response, # type: ignore + validated_output=cast(OT, validation_output.validated_output), + validation_passed=validation_output.result, + ) + if validation_output: + self._construct_history_from_server_response( + validation_output=validation_output, + llm_output=llm_output, + num_reasks=num_reasks, + prompt_params=prompt_params, + metadata=metadata, + full_schema_reask=full_schema_reask, + call_log=call_log, + stream=stream, + ) + else: + raise ValueError("Guard does not have an api client!") + + def _call_server( + self, + *args, + llm_output: Optional[str] = None, + llm_api: Optional[Callable] = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + metadata: Optional[Dict] = {}, + full_schema_reask: Optional[bool] = True, + call_log: Optional[Call], + **kwargs, + ) -> Union[ValidationOutcome[OT], Generator[ValidationOutcome[OT], None, None]]: + if self._api_client: + payload: Dict[str, Any] = {"args": list(args)} + payload.update(**kwargs) + if metadata: + payload["metadata"] = extract_serializeable_metadata(metadata) + if llm_output is not None: + payload["llmOutput"] = llm_output + if num_reasks is not None: + payload["numReasks"] = num_reasks + if prompt_params is not None: + payload["promptParams"] = prompt_params + if llm_api is not None: + payload["llmApi"] = get_llm_api_enum(llm_api, *args, **kwargs) + + should_stream = kwargs.get("stream", False) + if should_stream: + return self._stream_server_call( + payload=payload, + llm_output=llm_output, + num_reasks=num_reasks, + prompt_params=prompt_params, + metadata=metadata, + full_schema_reask=full_schema_reask, + call_log=call_log, + stream=should_stream, + ) + else: + return self._single_server_call( + payload=payload, + llm_output=llm_output, + num_reasks=num_reasks, + prompt_params=prompt_params, + metadata=metadata, + full_schema_reask=full_schema_reask, + call_log=call_log, + stream=should_stream, + ) + else: + raise ValueError("Guard does not have an api client!") + + def _save(self): + api_key = 
os.environ.get("GUARDRAILS_API_KEY") + if api_key is not None: + if self.name is None: + self.name = f"gr-{str(self._guard_id)}" + logger.warn("Warning: No name passed to guard!") + logger.warn( + "Use this auto-generated name to re-use this guard: {name}".format( + name=self.name + ) + ) + if not self._api_client: + self._api_client = GuardrailsApiClient(api_key=api_key) + self.upsert_guard() + def to_runnable(self) -> Runnable: from guardrails.integrations.langchain.guard_runnable import GuardRunnable diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py index 9bea2d850..a5b224330 100644 --- a/guardrails/llm_providers.py +++ b/guardrails/llm_providers.py @@ -781,6 +781,7 @@ async def invoke_llm( api_key = None aclient = AsyncOpenAIClient(api_key=api_key) + # FIXME: OpenAI async streaming seems to be broken return await aclient.create_chat_completion( model=model, messages=chat_prompt( @@ -831,6 +832,14 @@ async def invoke_llm( *args, **kwargs, ) + if kwargs.get("stream", False): + # If stream is defined and set to True, + # the callable returns a generator object + # response = cast(AsyncIterable[str], response) + return LLMResponse( + output="", + async_stream_output=response.completion_stream, # pyright: ignore[reportGeneralTypeIssues] + ) return LLMResponse( output=response.choices[0].message.content, # type: ignore @@ -872,6 +881,10 @@ async def invoke_llm( *args, **kwargs, ) + if kwargs.get("stream", False): + raise NotImplementedError( + "Manifest async streaming is not yet supported by manifest." + ) return LLMResponse( output=manifest_response[0], ) @@ -895,6 +908,13 @@ async def invoke_llm(self, *args, **kwargs) -> LLMResponse: ``` """ output = await self.llm_api(*args, **kwargs) + if kwargs.get("stream", False): + # If stream is defined and set to True, + # the callable returns a generator object + return LLMResponse( + output="", + async_stream_output=output.completion_stream, + ) return LLMResponse( output=output, ) @@ -939,16 +959,20 @@ def model_is_supported_server_side( model = get_llm_ask(llm_api, *args, **kwargs) if asyncio.iscoroutinefunction(llm_api): model = get_async_llm_ask(llm_api, *args, **kwargs) - return issubclass(type(model), OpenAIModel) or issubclass( - type(model), AsyncOpenAIModel + return ( + issubclass(type(model), OpenAIModel) + or issubclass(type(model), AsyncOpenAIModel) + or isinstance(model, LiteLLMCallable) + or isinstance(model, AsyncLiteLLMCallable) ) # FIXME: Update with newly supported LLMs def get_llm_api_enum( - llm_api: Callable[[Any], Awaitable[Any]], + llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs ) -> Optional[ValidatePayloadLlmApi]: # TODO: Distinguish between v1 and v2 + model = get_llm_ask(llm_api, *args, **kwargs) if llm_api == get_static_openai_create_func(): return ValidatePayloadLlmApi.OPENAI_COMPLETION_CREATE elif llm_api == get_static_openai_chat_create_func(): @@ -957,5 +981,10 @@ def get_llm_api_enum( return ValidatePayloadLlmApi.OPENAI_COMPLETION_ACREATE elif llm_api == get_static_openai_chat_acreate_func(): return ValidatePayloadLlmApi.OPENAI_CHATCOMPLETION_ACREATE + elif isinstance(model, LiteLLMCallable): + return ValidatePayloadLlmApi.LITELLM_COMPLETION + elif isinstance(model, AsyncLiteLLMCallable): + return ValidatePayloadLlmApi.LITELLM_ACOMPLETION + else: return None diff --git a/guardrails/run/__init__.py b/guardrails/run/__init__.py index e777106f7..9893f2124 100644 --- a/guardrails/run/__init__.py +++ b/guardrails/run/__init__.py @@ -1,12 +1,14 @@ from guardrails.run.async_runner import 
AsyncRunner from guardrails.run.runner import Runner from guardrails.run.stream_runner import StreamRunner +from guardrails.run.async_stream_runner import AsyncStreamRunner from guardrails.run.utils import msg_history_source, msg_history_string __all__ = [ "Runner", "AsyncRunner", "StreamRunner", + "AsyncStreamRunner", "msg_history_source", "msg_history_string", ] diff --git a/guardrails/run/async_runner.py b/guardrails/run/async_runner.py index 95b00d9c2..762cfa0d7 100644 --- a/guardrails/run/async_runner.py +++ b/guardrails/run/async_runner.py @@ -144,7 +144,8 @@ async def async_run( return call_log - @async_trace(name="step") + # TODO: Do we want to revert this name to step? + @async_trace(name="async_step") async def async_step( self, index: int, @@ -281,7 +282,6 @@ async def async_call( llm_response = await api_fn(prompt.source) else: raise ValueError("'output', 'prompt' or 'msg_history' must be provided.") - return llm_response async def async_validate( diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py new file mode 100644 index 000000000..fc1fd4018 --- /dev/null +++ b/guardrails/run/async_stream_runner.py @@ -0,0 +1,518 @@ +import copy +from functools import partial +from typing import ( + Any, + AsyncIterable, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + Awaitable, +) + +from pydantic import BaseModel + +from guardrails.classes import ValidationOutcome +from guardrails.classes.history import Call, Inputs, Iteration, Outputs +from guardrails.constants import pass_status +from guardrails.datatypes import verify_metadata_requirements +from guardrails.errors import ValidationError +from guardrails.llm_providers import ( + AsyncLiteLLMCallable, + AsyncPromptCallableBase, + LiteLLMCallable, + OpenAICallable, + OpenAIChatCallable, + PromptCallableBase, +) +from guardrails.prompt import Instructions, Prompt +from guardrails.run import StreamRunner +from guardrails.run.utils import msg_history_source, msg_history_string +from guardrails.schema import Schema, StringSchema +from guardrails.utils.llm_response import LLMResponse +from guardrails.utils.openai_utils import OPENAI_VERSION +from guardrails.utils.reask_utils import ReAsk, SkeletonReAsk +from guardrails.utils.telemetry_utils import async_trace +from guardrails.validator_base import ValidationResult + + +class AsyncStreamRunner(StreamRunner): + def __init__( + self, + output_schema: Schema, + num_reasks: int, + prompt: Optional[Union[str, Prompt]] = None, + instructions: Optional[Union[str, Instructions]] = None, + msg_history: Optional[List[Dict]] = None, + api: Optional[AsyncPromptCallableBase] = None, + prompt_schema: Optional[StringSchema] = None, + instructions_schema: Optional[StringSchema] = None, + msg_history_schema: Optional[StringSchema] = None, + metadata: Optional[Dict[str, Any]] = None, + output: Optional[str] = None, + base_model: Optional[ + Union[Type[BaseModel], Type[List[Type[BaseModel]]]] + ] = None, + full_schema_reask: bool = False, + disable_tracer: Optional[bool] = True, + ): + super().__init__( + output_schema=output_schema, + num_reasks=num_reasks, + prompt=prompt, + instructions=instructions, + msg_history=msg_history, + api=api, + prompt_schema=prompt_schema, + instructions_schema=instructions_schema, + msg_history_schema=msg_history_schema, + metadata=metadata, + output=output, + base_model=base_model, + full_schema_reask=full_schema_reask, + disable_tracer=disable_tracer, + ) + self.api: Optional[AsyncPromptCallableBase] = api + + 
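A minimal, self-contained sketch of the consumption pattern the new AsyncStreamRunner builds on: an async generator of text chunks (standing in for LLMResponse.async_stream_output) is iterated with `async for`, chunks are accumulated into a fragment, and a per-chunk result is emitted as each piece arrives. `fake_llm_stream` and `validate_chunk` below are illustrative stand-ins, not guardrails APIs.

import asyncio
from typing import AsyncIterator


async def fake_llm_stream() -> AsyncIterator[str]:
    # Stand-in for an async stream of LLM output chunks.
    for piece in ["The quick ", "brown fox ", "jumps."]:
        yield piece


def validate_chunk(chunk: str) -> bool:
    # Stand-in for per-chunk validation; real validators return ValidationResult.
    return "ERROR" not in chunk


async def run_stream() -> None:
    fragment = ""
    async for chunk in fake_llm_stream():
        fragment += chunk
        passed = validate_chunk(chunk)
        # Mirrors the yielded per-chunk outcome: raw chunk plus a pass/fail flag.
        print(f"chunk={chunk!r} passed={passed}")
    print(f"final fragment: {fragment!r}")


if __name__ == "__main__":
    asyncio.run(run_stream())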
async def async_run( + self, call_log: Call, prompt_params: Optional[Dict] = None + ) -> AsyncIterable[ValidationOutcome]: + if prompt_params is None: + prompt_params = {} + + missing_keys = verify_metadata_requirements( + self.metadata, self.output_schema.root_datatype + ) + + if missing_keys: + raise ValueError( + f"Missing required metadata keys: {', '.join(missing_keys)}" + ) + + ( + instructions, + prompt, + msg_history, + prompt_schema, + instructions_schema, + msg_history_schema, + output_schema, + ) = ( + self.instructions, + self.prompt, + self.msg_history, + self.prompt_schema, + self.instructions_schema, + self.msg_history_schema, + self.output_schema, + ) + + result = self.async_step( + index=0, + api=self.api, + instructions=instructions, + prompt=prompt, + msg_history=msg_history, + prompt_params=prompt_params, + prompt_schema=prompt_schema, + instructions_schema=instructions_schema, + msg_history_schema=msg_history_schema, + output_schema=output_schema, + output=self.output, + call_log=call_log, + ) + # FIXME: Where can this be moved to be less verbose? This is an await call on + # the async generator. + async for call in result: + yield call + + # @async_trace(name="step") + async def async_step( + self, + index: int, + api: Optional[AsyncPromptCallableBase], + instructions: Optional[Instructions], + prompt: Optional[Prompt], + msg_history: Optional[List[Dict]], + prompt_params: Dict, + prompt_schema: Optional[StringSchema], + instructions_schema: Optional[StringSchema], + msg_history_schema: Optional[StringSchema], + output_schema: Schema, + call_log: Call, + output: Optional[str] = None, + ) -> AsyncIterable[ValidationOutcome]: + inputs = Inputs( + llm_api=api, + llm_output=output, + instructions=instructions, + prompt=prompt, + msg_history=msg_history, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + metadata=self.metadata, + full_schema_reask=self.full_schema_reask, + stream=True, + ) + outputs = Outputs() + iteration = Iteration(inputs=inputs, outputs=outputs) + call_log.iterations.push(iteration) + if output: + instructions = None + prompt = None + msg_history = None + else: + instructions, prompt, msg_history = await self.async_prepare( + call_log, + index, + instructions, + prompt, + msg_history, + prompt_params, + api, + prompt_schema, + instructions_schema, + msg_history_schema, + output_schema, + ) + + iteration.inputs.prompt = prompt + iteration.inputs.instructions = instructions + iteration.inputs.msg_history = msg_history + + llm_response = await self.async_call( + index, instructions, prompt, msg_history, api, output + ) + stream_output = llm_response.async_stream_output + if not stream_output: + raise ValueError( + "No async stream was returned from the API. Please check that " + "the API is returning an async generator." 
+ ) + + fragment = "" + parsed_fragment, validated_fragment, valid_op = None, None, None + verified = set() + + if isinstance(output_schema, StringSchema): + async for chunk in stream_output: + chunk_text = self.get_chunk_text(chunk, api) + _ = self.is_last_chunk(chunk, api) + fragment += chunk_text + + parsed_chunk, move_to_next = self.parse( + index, chunk_text, output_schema, verified + ) + if move_to_next: + continue + validated_fragment = await self.async_validate( + iteration, + index, + parsed_chunk, + output_schema, + validate_subschema=True, + stream=True, + ) + if isinstance(validated_fragment, SkeletonReAsk): + raise ValueError( + "Received fragment schema is an invalid sub-schema " + "of the expected output JSON schema." + ) + + reasks, valid_op = await self.introspect( + index, validated_fragment, output_schema + ) + if reasks: + raise ValueError( + "Reasks are not yet supported with streaming. Please " + "remove reasks from schema or disable streaming." + ) + passed = call_log.status == pass_status + yield ValidationOutcome( + raw_llm_output=chunk_text, + validated_output=validated_fragment, + validation_passed=passed, + ) + else: + async for chunk in stream_output: + chunk_text = self.get_chunk_text(chunk, api) + fragment += chunk_text + + parsed_fragment, move_to_next = self.parse( + index, fragment, output_schema, verified + ) + if move_to_next: + continue + validated_fragment = await self.async_validate( + iteration, + index, + parsed_fragment, + output_schema, + validate_subschema=True, + ) + if isinstance(validated_fragment, SkeletonReAsk): + raise ValueError( + "Received fragment schema is an invalid sub-schema " + "of the expected output JSON schema." + ) + + reasks, valid_op = await self.introspect( + index, validated_fragment, output_schema + ) + if reasks: + raise ValueError( + "Reasks are not yet supported with streaming. Please " + "remove reasks from schema or disable streaming." 
+ ) + + yield ValidationOutcome( + raw_llm_output=fragment, + validated_output=chunk_text, + validation_passed=validated_fragment is not None, + ) + + iteration.outputs.raw_output = fragment + iteration.outputs.parsed_output = parsed_fragment + iteration.outputs.guarded_output = valid_op + iteration.outputs.validation_response = ( + cast(str, validated_fragment) if validated_fragment else None + ) + + @async_trace(name="call") + async def async_call( + self, + index: int, + instructions: Optional[Instructions], + prompt: Optional[Prompt], + msg_history: Optional[List[Dict]], + api: Optional[AsyncPromptCallableBase], + output: Optional[str] = None, + ) -> LLMResponse: + api_fn = api + if api is not None: + supports_base_model = getattr(api, "supports_base_model", False) + if supports_base_model: + api_fn = partial(api, base_model=self.base_model) + if output is not None: + llm_response = LLMResponse( + output=output, + ) + elif api_fn is None: + raise ValueError("Either API or output must be provided.") + elif msg_history: + llm_response = await api_fn(msg_history=msg_history_source(msg_history)) + elif prompt and instructions: + llm_response = await api_fn(prompt.source, instructions=instructions.source) + elif prompt: + llm_response = await api_fn(prompt.source) + else: + raise ValueError("'output', 'prompt' or 'msg_history' must be provided.") + return llm_response + + async def async_validate( + self, + iteration: Iteration, + index: int, + parsed_output: Any, + output_schema: Schema, + validate_subschema: bool = False, + stream: Optional[bool] = False, + ) -> Optional[Union[Awaitable[ValidationResult], ValidationResult]]: + # FIXME: Subschema is currently broken, it always returns a string from async + # streaming. + # Should return None/empty if fail result? + _ = await output_schema.async_validate( + iteration, parsed_output, self.metadata, attempt_number=index, stream=stream + ) + try: + return iteration.outputs.validator_logs[-1].validation_result + except IndexError: + return None + + async def introspect( + self, + index: int, + validated_output: Any, + output_schema: Schema, + ) -> Tuple[Sequence[ReAsk], Any]: + # Introspect: inspect validated output for reasks. 
+ if validated_output is None: + return [], None + reasks, valid_output = output_schema.introspect(validated_output) + + return reasks, valid_output + + async def async_prepare( + self, + call_log: Call, + index: int, + instructions: Optional[Instructions], + prompt: Optional[Prompt], + msg_history: Optional[List[Dict]], + prompt_params: Dict, + api: Optional[Union[PromptCallableBase, AsyncPromptCallableBase]], + prompt_schema: Optional[StringSchema], + instructions_schema: Optional[StringSchema], + msg_history_schema: Optional[StringSchema], + output_schema: Schema, + ) -> Tuple[Optional[Instructions], Optional[Prompt], Optional[List[Dict]]]: + if api is None: + raise ValueError("API must be provided.") + + if prompt_params is None: + prompt_params = {} + + if msg_history: + msg_history = copy.deepcopy(msg_history) + for msg in msg_history: + msg["content"] = msg["content"].format(**prompt_params) + + prompt, instructions = None, None + + if msg_history_schema is not None: + msg_str = msg_history_string(msg_history) + inputs = Inputs( + llm_output=msg_str, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.insert(0, iteration) + validated_msg_history = await msg_history_schema.async_validate( + iteration, msg_str, self.metadata + ) + if isinstance(validated_msg_history, ReAsk): + raise ValidationError( + f"Message history validation failed: " + f"{validated_msg_history}" + ) + if validated_msg_history != msg_str: + raise ValidationError("Message history validation failed") + elif prompt is not None: + if isinstance(prompt, str): + prompt = Prompt(prompt) + + prompt = prompt.format(**prompt_params) + + if instructions is not None and isinstance(instructions, Instructions): + instructions = instructions.format(**prompt_params) + + instructions, prompt = output_schema.preprocess_prompt( + api, instructions, prompt + ) + + if prompt_schema is not None and prompt is not None: + inputs = Inputs( + llm_output=prompt.source, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.insert(0, iteration) + validated_prompt = await prompt_schema.async_validate( + iteration, prompt.source, self.metadata + ) + iteration.outputs.validation_response = validated_prompt + if validated_prompt is None: + raise ValidationError("Prompt validation failed") + if isinstance(validated_prompt, ReAsk): + raise ValidationError( + f"Prompt validation failed: {validated_prompt}" + ) + prompt = Prompt(validated_prompt) + + if instructions_schema is not None and instructions is not None: + inputs = Inputs( + llm_output=instructions.source, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.insert(0, iteration) + validated_instructions = await instructions_schema.async_validate( + iteration, instructions.source, self.metadata + ) + iteration.outputs.validation_response = validated_instructions + if validated_instructions is None: + raise ValidationError("Instructions validation failed") + if isinstance(validated_instructions, ReAsk): + raise ValidationError( + f"Instructions validation failed: {validated_instructions}" + ) + instructions = Instructions(validated_instructions) + else: + raise ValueError("Prompt or message history must be provided.") + + return instructions, prompt, msg_history + + def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str: + """Get the text from a chunk.""" + chunk_text = "" + if isinstance(api, OpenAICallable): + if OPENAI_VERSION.startswith("0"): + finished = chunk["choices"][0]["finish_reason"] + if "text" in 
chunk["choices"][0]: + content = chunk["choices"][0]["text"] + if not finished and content: + chunk_text = content + else: + finished = chunk.choices[0].finish_reason + content = chunk.choices[0].text + if not finished and content: + chunk_text = content + elif isinstance(api, OpenAIChatCallable): + if OPENAI_VERSION.startswith("0"): + finished = chunk["choices"][0]["finish_reason"] + if "content" in chunk["choices"][0]["delta"]: + content = chunk["choices"][0]["delta"]["content"] + if not finished and content: + chunk_text = content + else: + finished = chunk.choices[0].finish_reason + content = chunk.choices[0].delta.content + if not finished and content: + chunk_text = content + elif isinstance(api, LiteLLMCallable): + finished = chunk.choices[0].finish_reason + content = chunk.choices[0].delta.content + if not finished and content: + chunk_text = content + elif isinstance(api, AsyncLiteLLMCallable): + finished = chunk.choices[0].finish_reason + content = chunk.choices[0].delta.content + if not finished and content: + chunk_text = content + else: + try: + chunk_text = chunk + except Exception as e: + raise ValueError( + f"Error getting chunk from stream: {e}. " + "Non-OpenAI API callables expected to return " + "a generator of strings." + ) from e + return chunk_text + + def is_last_chunk(self, chunk: Any, api: Union[PromptCallableBase, None]) -> bool: + """Detect if chunk is final chunk.""" + if isinstance(api, OpenAICallable): + if OPENAI_VERSION.startswith("0"): + finished = chunk["choices"][0]["finish_reason"] + return finished is not None + else: + finished = chunk.choices[0].finish_reason + return finished is not None + elif isinstance(api, OpenAIChatCallable): + if OPENAI_VERSION.startswith("0"): + finished = chunk["choices"][0]["finish_reason"] + return finished is not None + else: + finished = chunk.choices[0].finish_reason + return finished is not None + elif isinstance(api, LiteLLMCallable): + finished = chunk.choices[0].finish_reason + return finished is not None + else: + try: + finished = chunk.choices[0].finish_reason + return finished is not None + except (AttributeError, TypeError): + return False diff --git a/guardrails/run/runner.py b/guardrails/run/runner.py index 9688d43ea..eba815973 100644 --- a/guardrails/run/runner.py +++ b/guardrails/run/runner.py @@ -547,17 +547,29 @@ def validate( index: int, parsed_output: Any, output_schema: Schema, + stream: Optional[bool] = False, **kwargs, ): """Validate the output.""" - validated_output = output_schema.validate( - iteration, - parsed_output, - self.metadata, - attempt_number=index, - disable_tracer=self._disable_tracer, - **kwargs, - ) + if isinstance(output_schema, StringSchema): + validated_output = output_schema.validate( + iteration, + parsed_output, + self.metadata, + index, + self._disable_tracer, + stream, + **kwargs, + ) + else: + validated_output = output_schema.validate( + iteration, + parsed_output, + self.metadata, + attempt_number=index, + disable_tracer=self._disable_tracer, + **kwargs, + ) return validated_output diff --git a/guardrails/run/stream_runner.py b/guardrails/run/stream_runner.py index 890c3fd9c..56e7873eb 100644 --- a/guardrails/run/stream_runner.py +++ b/guardrails/run/stream_runner.py @@ -15,6 +15,7 @@ from guardrails.schema import Schema, StringSchema from guardrails.utils.openai_utils import OPENAI_VERSION from guardrails.utils.reask_utils import SkeletonReAsk +from guardrails.constants import pass_status class StreamRunner(Runner): @@ -108,6 +109,7 @@ def step( num_reasks=self.num_reasks, 
metadata=self.metadata, full_schema_reask=self.full_schema_reask, + stream=True, ) outputs = Outputs() iteration = Iteration(inputs=inputs, outputs=outputs) @@ -153,47 +155,118 @@ def step( verified = set() # Loop over the stream # and construct "fragments" of concatenated chunks - for chunk in stream: - # 1. Get the text from the chunk and append to fragment - chunk_text = self.get_chunk_text(chunk, api) - fragment += chunk_text + # for now, handle string and json schema differently - # 2. Parse the fragment - parsed_fragment, move_to_next = self.parse( - index, fragment, output_schema, verified - ) - if move_to_next: - # Continue to next chunk - continue + if isinstance(output_schema, StringSchema): + stream_finished = False + last_chunk_text = "" + for chunk in stream: + # 1. Get the text from the chunk and append to fragment + chunk_text = self.get_chunk_text(chunk, api) + last_chunk_text = chunk_text + finished = self.is_last_chunk(chunk, api) + if finished: + stream_finished = True + fragment += chunk_text - # 3. Run output validation - validated_fragment = self.validate( - iteration, - index, - parsed_fragment, - output_schema, - validate_subschema=True, - ) - if isinstance(validated_fragment, SkeletonReAsk): - raise ValueError( - "Received fragment schema is an invalid sub-schema " - "of the expected output JSON schema." + # 2. Parse the chunk + parsed_chunk, move_to_next = self.parse( + index, chunk_text, output_schema, verified + ) + if move_to_next: + # Continue to next chunk + continue + validated_text = self.validate( + iteration, + index, + parsed_chunk, + output_schema, + True, + validate_subschema=True, + # if it is the last chunk, validate everything that's left + remainder=finished, ) + if isinstance(validated_text, SkeletonReAsk): + raise ValueError( + "Received fragment schema is an invalid sub-schema " + "of the expected output JSON schema." + ) - # 4. Introspect: inspect the validated fragment for reasks - reasks, valid_op = self.introspect(index, validated_fragment, output_schema) - if reasks: - raise ValueError( - "Reasks are not yet supported with streaming. Please " - "remove reasks from schema or disable streaming." + # 4. Introspect: inspect the validated fragment for reasks + reasks, valid_op = self.introspect(index, validated_text, output_schema) + if reasks: + raise ValueError( + "Reasks are not yet supported with streaming. Please " + "remove reasks from schema or disable streaming." + ) + # 5. Convert validated fragment to a pretty JSON string + passed = call_log.status == pass_status + yield ValidationOutcome( + # The chunk or the whole output? + raw_llm_output=chunk_text, + validated_output=validated_text, + validation_passed=passed, + ) + # handle case where generator doesn't give finished status + if not stream_finished: + last_result = self.validate( + iteration, + index, + "", + output_schema, + True, + validate_subschema=True, + remainder=True, ) + if len(last_result) > 0: + passed = call_log.status == pass_status + yield ValidationOutcome( + raw_llm_output=last_chunk_text, + validated_output=last_result, + validation_passed=passed, + ) + # handle non string schema + else: + for chunk in stream: + # 1. Get the text from the chunk and append to fragment + chunk_text = self.get_chunk_text(chunk, api) + fragment += chunk_text - # 5. 
Convert validated fragment to a pretty JSON string - yield ValidationOutcome( - raw_llm_output=fragment, - validated_output=validated_fragment, - validation_passed=validated_fragment is not None, - ) + parsed_fragment, move_to_next = self.parse( + index, fragment, output_schema, verified + ) + if move_to_next: + # Continue to next chunk + continue + validated_fragment = self.validate( + iteration, + index, + parsed_fragment, + output_schema, + validate_subschema=True, + ) + if isinstance(validated_fragment, SkeletonReAsk): + raise ValueError( + "Received fragment schema is an invalid sub-schema " + "of the expected output JSON schema." + ) + + # 4. Introspect: inspect the validated fragment for reasks + reasks, valid_op = self.introspect( + index, validated_fragment, output_schema + ) + if reasks: + raise ValueError( + "Reasks are not yet supported with streaming. Please " + "remove reasks from schema or disable streaming." + ) + + # 5. Convert validated fragment to a pretty JSON string + yield ValidationOutcome( + raw_llm_output=fragment, + validated_output=validated_fragment, + validation_passed=validated_fragment is not None, + ) # Finally, add to logs iteration.outputs.raw_output = fragment @@ -201,6 +274,32 @@ def step( iteration.outputs.validation_response = validated_fragment iteration.outputs.guarded_output = valid_op + def is_last_chunk(self, chunk: Any, api: Union[PromptCallableBase, None]) -> bool: + """Detect if chunk is final chunk.""" + if isinstance(api, OpenAICallable): + if OPENAI_VERSION.startswith("0"): + finished = chunk["choices"][0]["finish_reason"] + return finished is not None + else: + finished = chunk.choices[0].finish_reason + return finished is not None + elif isinstance(api, OpenAIChatCallable): + if OPENAI_VERSION.startswith("0"): + finished = chunk["choices"][0]["finish_reason"] + return finished is not None + else: + finished = chunk.choices[0].finish_reason + return finished is not None + elif isinstance(api, LiteLLMCallable): + finished = chunk.choices[0].finish_reason + return finished is not None + else: + try: + finished = chunk.choices[0].finish_reason + return finished is not None + except (AttributeError, TypeError): + return False + def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str: """Get the text from a chunk.""" chunk_text = "" diff --git a/guardrails/schema/schema.py b/guardrails/schema/schema.py index 7d8dad8c0..6b6eb33cf 100644 --- a/guardrails/schema/schema.py +++ b/guardrails/schema/schema.py @@ -87,7 +87,13 @@ def validate( raise NotImplementedError async def async_validate( - self, iteration: Iteration, data: Any, metadata: Dict, attempt_number: int = 0 + self, + iteration: Iteration, + data: Any, + metadata: Dict, + attempt_number: int = 0, + stream: Optional[bool] = False, + **kwargs, ) -> Any: """Asynchronously validate a dictionary of data against the schema. diff --git a/guardrails/schema/string_schema.py b/guardrails/schema/string_schema.py index cbff5d5d1..4b9e76634 100644 --- a/guardrails/schema/string_schema.py +++ b/guardrails/schema/string_schema.py @@ -134,6 +134,7 @@ def validate( metadata: Dict, attempt_number: int = 0, disable_tracer: Optional[bool] = True, + stream: Optional[bool] = False, **kwargs, ) -> Any: """Validate a dictionary of data against the schema. 
@@ -160,19 +161,20 @@ def validate( dummy_key: data, }, ) - validated_response, metadata = validator_service.validate( value=data, metadata=metadata, validator_setup=validation, iteration=iteration, disable_tracer=disable_tracer, + stream=stream, + **kwargs, ) validated_response = {dummy_key: validated_response} if check_refrain_in_dict(validated_response): - # If the data contains a `Refain` value, we return an empty + # If the data contains a `Refrain` value, we return an empty # dictionary. logger.debug("Refrain detected.") validated_response = {} @@ -194,6 +196,7 @@ async def async_validate( data: Any, metadata: Dict, attempt_number: int = 0, + stream: Optional[bool] = False, ) -> Any: """Validate a dictionary of data against the schema. @@ -223,6 +226,7 @@ async def async_validate( metadata=metadata, validator_setup=validation, iteration=iteration, + stream=stream, ) validated_response = {dummy_key: validated_response} diff --git a/guardrails/utils/api_utils.py b/guardrails/utils/api_utils.py new file mode 100644 index 000000000..8d91f64f2 --- /dev/null +++ b/guardrails/utils/api_utils.py @@ -0,0 +1,16 @@ +import json +from typing import Any, Dict + + +def try_to_json(value: Any): + try: + json.dumps(value) + return True + except ValueError: + return False + except TypeError: + return False + + +def extract_serializeable_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]: + return {k: metadata[k] for k in metadata if try_to_json(metadata[k])} diff --git a/guardrails/utils/hub_telemetry_utils.py b/guardrails/utils/hub_telemetry_utils.py index 719f80050..6d9517130 100644 --- a/guardrails/utils/hub_telemetry_utils.py +++ b/guardrails/utils/hub_telemetry_utils.py @@ -115,7 +115,7 @@ def create_new_span( if self._tracer is None: return with self._tracer.start_as_current_span( - span_name, + span_name, # type: ignore (Fails in Python 3.8 for invalid reason) context=self.extract_current_context() if has_parent else None, ) as span: if is_parent: diff --git a/guardrails/utils/llm_response.py b/guardrails/utils/llm_response.py index f0248fe0d..b797f5b5a 100644 --- a/guardrails/utils/llm_response.py +++ b/guardrails/utils/llm_response.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional +from typing import Iterable, Optional, AsyncIterable from guardrails.utils.pydantic_utils import ArbitraryModel @@ -8,3 +8,4 @@ class LLMResponse(ArbitraryModel): response_token_count: Optional[int] = None output: str stream_output: Optional[Iterable] = None + async_stream_output: Optional[AsyncIterable] = None diff --git a/guardrails/utils/telemetry_utils.py b/guardrails/utils/telemetry_utils.py index 599d013b6..35c453491 100644 --- a/guardrails/utils/telemetry_utils.py +++ b/guardrails/utils/telemetry_utils.py @@ -156,7 +156,8 @@ def with_trace(*args, **kwargs): if _tracer is None: return fn(*args, **kwargs) with _tracer.start_as_current_span( - span_name, trace_context + span_name, # type: ignore (Fails in Python 3.8 for invalid reason) + trace_context, ) as validator_span: try: validator_span.set_attribute( @@ -200,7 +201,7 @@ def to_trace_or_not_to_trace(*args, **kwargs): if _tracer is not None and hasattr(_tracer, "start_as_current_span"): trace_context = get_current_context() - with _tracer.start_as_current_span(name, trace_context) as trace_span: + with _tracer.start_as_current_span(name, trace_context) as trace_span: # type: ignore (Fails in Python 3.8 for invalid reason) try: # TODO: Capture args and kwargs as attributes? 
response = fn(*args, **kwargs) @@ -226,7 +227,7 @@ async def to_trace_or_not_to_trace(*args, **kwargs): if _tracer is not None and hasattr(_tracer, "start_as_current_span"): trace_context = get_current_context() - with _tracer.start_as_current_span(name, trace_context) as trace_span: + with _tracer.start_as_current_span(name, trace_context) as trace_span: # type: ignore (Fails in Python 3.8 for invalid reason) try: # TODO: Capture args and kwargs as attributes? response = await fn(*args, **kwargs) diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index fa01573f1..1a636d6c5 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -1,4 +1,5 @@ import inspect +import nltk from collections import defaultdict from copy import deepcopy from enum import Enum @@ -176,6 +177,37 @@ class Refrain: pass +# functions to get chunks + + +def split_sentence_str(chunk: str): + """A naive sentence splitter that splits on periods.""" + if "." not in chunk: + return [] + fragments = chunk.split(".") + return [fragments[0] + ".", ".".join(fragments[1:])] + + +def split_sentence_nltk(chunk: str): + """ + NOTE: this approach currently does not work + Use a sentence tokenizer to split the chunk into sentences. + + Because using the tokenizer is expensive, we only use it if there + is a period present in the chunk. + """ + # using the sentence tokenizer is expensive + # we check for a . to avoid wastefully calling the tokenizer + if "." not in chunk: + return [] + sentences = nltk.sent_tokenize(chunk) + if len(sentences) == 0: + return [] + # return the sentence + # then the remaining chunks that aren't finished accumulating + return [sentences[0], "".join(sentences[1:])] + + def check_refrain_in_list(schema: List) -> bool: """Checks if a Refrain object exists in a list. @@ -356,6 +388,9 @@ def get_validator(name: str): class ValidationResult(BaseModel): outcome: str metadata: Optional[Dict[str, Any]] = None + # value argument passed to validator.validate + # or validator.validate_stream + validated_chunk: Optional[Any] = None class PassResult(ValidationResult): @@ -368,11 +403,21 @@ class ValueOverrideSentinel: value_override: Optional[Any] = Field(default=ValueOverrideSentinel) +# specifies the start and end of segment of validate_chunk +class ErrorSpan(BaseModel): + start: int + end: int + # reason validation failed, specific to this chunk + reason: str + + class FailResult(ValidationResult): outcome: Literal["fail"] = "fail" error_message: str fix_value: Optional[Any] = None + # segments that caused validation to fail + error_spans: Optional[List[ErrorSpan]] = None class OnFailAction(str, Enum): @@ -391,6 +436,10 @@ class Validator(Runnable): rail_alias: str = "" + # chunking function returns empty list or list of 2 chunks + # first chunk is the chunk to validate + # second chunk is incomplete chunk that needs further accumulation + accumulated_chunks = [] run_in_separate_process = False override_value_on_pass = False required_metadata_keys = [] @@ -449,10 +498,49 @@ def __init__( self.rail_alias in validators_registry ), f"Validator {self.__class__.__name__} is not registered. 
" + def chunking_function(self, chunk: str): + return split_sentence_str(chunk) + def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult: """Validates a value and return a validation result.""" raise NotImplementedError + def validate_stream( + self, chunk: Any, metadata: Dict[str, Any], **kwargs + ) -> Optional[ValidationResult]: + """Validates a chunk emitted by an LLM. If the LLM chunk is smaller + than the validator's chunking strategy, it will be accumulated until it + reaches the desired size. In the meantime, the validator will return + None. + + If the LLM chunk is larger than the validator's chunking + strategy, it will split it into validator-sized chunks and + validate each one, returning an array of validation results. + + Otherwise, the validator will validate the chunk and return the + result. + """ + # combine accumulated chunks and new [:-1]chunk + self.accumulated_chunks.append(chunk) + accumulated_text = "".join(self.accumulated_chunks) + # check if enough chunks have accumulated for validation + splitcontents = self.chunking_function(accumulated_text) + + # if remainder kwargs is passed, validate remainder regardless + remainder = kwargs.get("remainder", False) + if remainder: + splitcontents = [accumulated_text, ""] + if len(splitcontents) == 0: + return PassResult() + [chunk_to_validate, new_accumulated_chunks] = splitcontents + self.accumulated_chunks = [new_accumulated_chunks] + # exclude last chunk, because it may not be a complete chunk + validation_result = self.validate(chunk_to_validate, metadata) + # if validate doesn't set validated chunk, we set it + if validation_result.validated_chunk is None: + validation_result.validated_chunk = chunk_to_validate + return validation_result + def to_prompt(self, with_keywords: bool = True) -> str: """Convert the validator to a prompt. diff --git a/guardrails/validator_service.py b/guardrails/validator_service.py index 92b7879a2..070dbe230 100644 --- a/guardrails/validator_service.py +++ b/guardrails/validator_service.py @@ -3,12 +3,12 @@ import os from concurrent.futures import ProcessPoolExecutor from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union, cast from guardrails.classes.history import Iteration from guardrails.datatypes import FieldValidation from guardrails.errors import ValidationError -from guardrails.logger import logger +from guardrails.utils.exception_utils import UserFacingException from guardrails.utils.hub_telemetry_utils import HubTelemetry from guardrails.utils.logs_utils import ValidatorLogs from guardrails.utils.reask_utils import FieldReAsk, ReAsk @@ -24,6 +24,8 @@ Validator, ) +ValidatorResult = Optional[Union[ValidationResult, Awaitable[ValidationResult]]] + def key_not_empty(key: str) -> bool: return key is not None and len(str(key)) > 0 @@ -42,8 +44,14 @@ def __init__(self, disable_tracer: Optional[bool] = True): # Using `fork` instead of `spawn` may alleviate the symptom for POSIX systems, # but is relatively unsupported on Windows. 
def execute_validator( - self, validator: Validator, value: Any, metadata: Optional[Dict] - ) -> ValidationResult: + self, + validator: Validator, + value: Any, + metadata: Optional[Dict], + stream: Optional[bool] = False, + **kwargs, + ) -> ValidatorResult: + validate_func = validator.validate_stream if stream else validator.validate traced_validator = trace_validator( validator_name=validator.rail_alias, obj_id=id(validator), @@ -51,8 +59,8 @@ def execute_validator( # namespace=validator.namespace, on_fail_descriptor=validator.on_fail_descriptor, **validator._kwargs, - )(validator.validate) - result = traced_validator(value, metadata) + )(validate_func) + result = traced_validator(value, metadata, **kwargs) return result def perform_correction( @@ -61,6 +69,7 @@ def perform_correction( value: Any, validator: Validator, on_fail_descriptor: Union[OnFailAction, str], + rechecked_value: Optional[ValidationResult] = None, ): if on_fail_descriptor == OnFailAction.FIX: # FIXME: Should we still return fix_value if it is None? @@ -69,11 +78,8 @@ def perform_correction( elif on_fail_descriptor == OnFailAction.FIX_REASK: # FIXME: Same thing here fixed_value = results[0].fix_value - result = self.execute_validator( - validator, fixed_value, results[0].metadata or {} - ) - if isinstance(result, FailResult): + if isinstance(rechecked_value, FailResult): return FieldReAsk( incorrect_value=fixed_value, fail_results=results, @@ -106,12 +112,11 @@ def perform_correction( f"expected 'fix' or 'exception'." ) - def run_validator( + def before_run_validator( self, iteration: Iteration, validator: Validator, value: Any, - metadata: Dict, property_path: str, ) -> ValidatorLogs: validator_class_name = validator.__class__.__name__ @@ -120,21 +125,26 @@ def run_validator( value_before_validation=value, registered_name=validator.rail_alias, property_path=property_path, + # If we ever re-use validator instances across multiple properties, + # this will have to change. + instance_id=id(validator), ) iteration.outputs.validator_logs.append(validator_logs) start_time = datetime.now() - result = self.execute_validator(validator, value, metadata) - end_time = datetime.now() - if result is None: - result = PassResult() + validator_logs.start_time = start_time + return validator_logs + + def after_run_validator( + self, + validator: Validator, + validator_logs: ValidatorLogs, + result: ValidationResult, + ): + end_time = datetime.now() validator_logs.validation_result = result - validator_logs.start_time = start_time validator_logs.end_time = end_time - # If we ever re-use validator instances across multiple properties, - # this will have to change. 
- validator_logs.instance_id = id(validator) if not self._disable_tracer: # Get HubTelemetry singleton and create a new span to @@ -145,7 +155,12 @@ def run_validator( attributes=[ ("validator_name", validator.rail_alias), ("validator_on_fail", validator.on_fail_descriptor), - ("validator_result", result.outcome), + ( + "validator_result", + result.outcome + if isinstance(result, ValidationResult) + else None, + ), ], is_parent=False, # This span will have no children has_parent=True, # This span has a parent @@ -153,8 +168,61 @@ def run_validator( return validator_logs + def run_validator( + self, + iteration: Iteration, + validator: Validator, + value: Any, + metadata: Dict, + property_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> ValidatorLogs: + raise NotImplementedError + class SequentialValidatorService(ValidatorServiceBase): + def run_validator_sync( + self, + validator: Validator, + value: Any, + metadata: Dict, + validator_logs: ValidatorLogs, + stream: Optional[bool] = False, + **kwargs, + ) -> ValidationResult: + result = self.execute_validator(validator, value, metadata, stream, **kwargs) + if asyncio.iscoroutine(result): + raise UserFacingException( + ValueError( + "Cannot use async validators with a synchronous Guard! " + f"Either use AsyncGuard or remove {validator_logs.validator_name}." + ) + ) + elif result is None: + result = PassResult() + return cast(ValidationResult, result) + + def run_validator( + self, + iteration: Iteration, + validator: Validator, + value: Any, + metadata: Dict, + property_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> ValidatorLogs: + validator_logs = self.before_run_validator( + iteration, validator, value, property_path + ) + + result = self.run_validator_sync( + validator, value, metadata, validator_logs, stream, **kwargs + ) + + return self.after_run_validator(validator, validator_logs, result) + def run_validators( self, iteration: Iteration, @@ -162,17 +230,34 @@ def run_validators( value: Any, metadata: Dict[str, Any], property_path: str, + stream: Optional[bool] = False, + **kwargs, ) -> Tuple[Any, Dict[str, Any]]: # Validate the field for validator in validator_setup.validators: validator_logs = self.run_validator( - iteration, validator, value, metadata, property_path + iteration, validator, value, metadata, property_path, stream, **kwargs ) - result = validator_logs.validation_result + result = cast(ValidationResult, result) if isinstance(result, FailResult): + rechecked_value = None + if validator.on_fail_descriptor == OnFailAction.FIX_REASK: + fixed_value = result.fix_value + rechecked_value = self.run_validator_sync( + validator, + fixed_value, + metadata, + validator_logs, + stream, + **kwargs, + ) value = self.perform_correction( - [result], value, validator, validator.on_fail_descriptor + [result], + value, + validator, + validator.on_fail_descriptor, + rechecked_value=rechecked_value, ) elif isinstance(result, PassResult): if ( @@ -180,11 +265,11 @@ def run_validators( and result.value_override is not result.ValueOverrideSentinel ): value = result.value_override - else: + elif not stream: raise RuntimeError(f"Unexpected result type {type(result)}") validator_logs.value_after_validation = value - if result.metadata is not None: + if result and result.metadata is not None: metadata = result.metadata if isinstance(value, (Refrain, Filter, ReAsk)): @@ -229,6 +314,29 @@ def validate( value, metadata = self.run_validators( iteration, validator_setup, value, metadata, property_path ) + return value, 
metadata + + def validate_stream( + self, + value: Any, + metadata: dict, + validator_setup: FieldValidation, + iteration: Iteration, + path: str = "$", + **kwargs, + ) -> Tuple[Any, dict]: + property_path = ( + f"{path}.{validator_setup.key}" + if key_not_empty(validator_setup.key) + else path + ) + # I assume validate stream doesn't need validate_dependents + # because right now we're only handling StringSchema + + # Validate the field + value, metadata = self.run_validators( + iteration, validator_setup, value, metadata, property_path, True, **kwargs + ) return value, metadata @@ -245,6 +353,46 @@ def __init__(self): class AsyncValidatorService(ValidatorServiceBase, MultiprocMixin): + async def run_validator_async( + self, + validator: Validator, + value: Any, + metadata: Dict, + stream: Optional[bool] = False, + **kwargs, + ) -> ValidationResult: + result: ValidatorResult = self.execute_validator( + validator, value, metadata, stream, **kwargs + ) + if asyncio.iscoroutine(result): + result = await result + + if result is None: + result = PassResult() + else: + result = cast(ValidationResult, result) + return result + + async def run_validator( + self, + iteration: Iteration, + validator: Validator, + value: Any, + metadata: Dict, + property_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> ValidatorLogs: + validator_logs = self.before_run_validator( + iteration, validator, value, property_path + ) + + result = await self.run_validator_async( + validator, value, metadata, stream, **kwargs + ) + + return self.after_run_validator(validator, validator_logs, result) + def group_validators(self, validators): groups = itertools.groupby( validators, key=lambda v: (v.on_fail_descriptor, v.override_value_on_pass) @@ -267,6 +415,7 @@ async def run_validators( value: Any, metadata: Dict, property_path: str, + stream: Optional[bool] = False, ): loop = asyncio.get_running_loop() for on_fail, validator_group in self.group_validators( @@ -286,19 +435,30 @@ async def run_validators( value, metadata, property_path, + stream, ) ) else: # run the validators in the current process - result = self.run_validator( - iteration, validator, value, metadata, property_path + result = await self.run_validator( + iteration, + validator, + value, + metadata, + property_path, + stream=stream, ) validators_logs.append(result) # wait for the parallel tasks to finish if parallel_tasks: parallel_results = await asyncio.gather(*parallel_tasks) - validators_logs.extend(parallel_results) + awaited_results = [] + for res in parallel_results: + if asyncio.iscoroutine(res): + res = await res + awaited_results.append(res) + validators_logs.extend(awaited_results) # process the results, handle failures fails = [ @@ -308,8 +468,19 @@ async def run_validators( ] if fails: fail_results = [logs.validation_result for logs in fails] + rechecked_value = None + validator: Validator = validator_group[0] + if validator.on_fail_descriptor == OnFailAction.FIX_REASK: + fixed_value = fail_results[0].fix_value + rechecked_value = await self.run_validator_async( + validator, fixed_value, fail_results[0].metadata or {}, stream + ) value = self.perform_correction( - fail_results, value, validator_group[0], on_fail + fail_results, + value, + validator_group[0], + on_fail, + rechecked_value=rechecked_value, ) # handle overrides @@ -363,6 +534,7 @@ async def async_validate( validator_setup: FieldValidation, iteration: Iteration, path: str = "$", + stream: Optional[bool] = False, ) -> Tuple[Any, dict]: property_path = ( 
f"{path}.{validator_setup.key}" @@ -377,7 +549,7 @@ async def async_validate( # Validate the field value, metadata = await self.run_validators( - iteration, validator_setup, value, metadata, property_path + iteration, validator_setup, value, metadata, property_path, stream=stream ) return value, metadata @@ -412,27 +584,27 @@ def validate( validator_setup: FieldValidation, iteration: Iteration, disable_tracer: Optional[bool] = True, + stream: Optional[bool] = False, + **kwargs, ): process_count = int(os.environ.get("GUARDRAILS_PROCESS_COUNT", 10)) - + if stream: + sequential_validator_service = SequentialValidatorService(disable_tracer) + return sequential_validator_service.validate_stream( + value, metadata, validator_setup, iteration, **kwargs + ) try: loop = asyncio.get_event_loop() except RuntimeError: loop = None if process_count == 1: - logger.warning( - "Process count was set to 1 via the GUARDRAILS_PROCESS_COUNT" - "environment variable." - "This will cause all validations to run synchronously." - "To run asynchronously, specify a process count" - "greater than 1 or unset this environment variable." - ) validator_service = SequentialValidatorService(disable_tracer) elif loop is not None and not loop.is_running(): validator_service = AsyncValidatorService(disable_tracer) else: validator_service = SequentialValidatorService(disable_tracer) + return validator_service.validate( value, metadata, @@ -447,11 +619,9 @@ async def async_validate( validator_setup: FieldValidation, iteration: Iteration, disable_tracer: Optional[bool] = True, -): + stream: Optional[bool] = False, +) -> Tuple[Any, dict]: validator_service = AsyncValidatorService(disable_tracer) return await validator_service.async_validate( - value, - metadata, - validator_setup, - iteration, + value, metadata, validator_setup, iteration, "$", stream ) diff --git a/guardrails/validatorsattr.py b/guardrails/validatorsattr.py index fbf5e3f99..9d5468ebf 100644 --- a/guardrails/validatorsattr.py +++ b/guardrails/validatorsattr.py @@ -327,7 +327,10 @@ def get_validators( continue # See if the formatter has an associated on_fail method. - on_fail = on_fail_handlers.get(validator_name, None) + escaped_validator_name = validator_name.replace("/", "_") + on_fail = on_fail_handlers.get( + validator_name, on_fail_handlers.get(escaped_validator_name) + ) # TODO(shreya): Load the on_fail method. # This method should be loaded from an optional script given at the # beginning of a rail file. 
diff --git a/poetry.lock b/poetry.lock index 228329555..a4cd0ea91 100644 --- a/poetry.lock +++ b/poetry.lock @@ -368,17 +368,6 @@ files = [ {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"}, ] -[[package]] -name = "backoff" -version = "2.2.1" -description = "Function decoration for backoff and retry" -optional = false -python-versions = ">=3.7,<4.0" -files = [ - {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, - {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, -] - [[package]] name = "backports-tarfile" version = "1.1.1" @@ -1741,71 +1730,71 @@ colorama = ">=0.4" [[package]] name = "grpcio" -version = "1.63.0" +version = "1.64.0" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, - {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, - {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, - {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, - {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, - {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, - {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, - {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = 
"sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, - {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, - {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, - {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, - {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, - {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, - {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, - {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, - {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, - {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, - {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, - {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, - {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, - {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.63.0)"] + {file = "grpcio-1.64.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:3b09c3d9de95461214a11d82cc0e6a46a6f4e1f91834b50782f932895215e5db"}, + {file = "grpcio-1.64.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:7e013428ab472892830287dd082b7d129f4d8afef49227a28223a77337555eaa"}, + {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:02cc9cc3f816d30f7993d0d408043b4a7d6a02346d251694d8ab1f78cc723e7e"}, + {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f5de082d936e0208ce8db9095821361dfa97af8767a6607ae71425ac8ace15c"}, + {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7b7bf346391dffa182fba42506adf3a84f4a718a05e445b37824136047686a1"}, + {file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b2cbdfba18408389a1371f8c2af1659119e1831e5ed24c240cae9e27b4abc38d"}, + {file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:aca4f15427d2df592e0c8f3d38847e25135e4092d7f70f02452c0e90d6a02d6d"}, + {file = "grpcio-1.64.0-cp310-cp310-win32.whl", hash = "sha256:7c1f5b2298244472bcda49b599be04579f26425af0fd80d3f2eb5fd8bc84d106"}, + {file = "grpcio-1.64.0-cp310-cp310-win_amd64.whl", hash = "sha256:73f84f9e5985a532e47880b3924867de16fa1aa513fff9b26106220c253c70c5"}, + {file = "grpcio-1.64.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2a18090371d138a57714ee9bffd6c9c9cb2e02ce42c681aac093ae1e7189ed21"}, + {file = "grpcio-1.64.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:59c68df3a934a586c3473d15956d23a618b8f05b5e7a3a904d40300e9c69cbf0"}, + {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b52e1ec7185512103dd47d41cf34ea78e7a7361ba460187ddd2416b480e0938c"}, + {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d598b5d5e2c9115d7fb7e2cb5508d14286af506a75950762aa1372d60e41851"}, + {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01615bbcae6875eee8091e6b9414072f4e4b00d8b7e141f89635bdae7cf784e5"}, + {file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0b2dfe6dcace264807d9123d483d4c43274e3f8c39f90ff51de538245d7a4145"}, + {file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7f17572dc9acd5e6dfd3014d10c0b533e9f79cd9517fc10b0225746f4c24b58e"}, + {file = "grpcio-1.64.0-cp311-cp311-win32.whl", hash = "sha256:6ec5ed15b4ffe56e2c6bc76af45e6b591c9be0224b3fb090adfb205c9012367d"}, + {file = "grpcio-1.64.0-cp311-cp311-win_amd64.whl", hash = "sha256:597191370951b477b7a1441e1aaa5cacebeb46a3b0bd240ec3bb2f28298c7553"}, + {file = "grpcio-1.64.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:1ce4cd5a61d4532651079e7aae0fedf9a80e613eed895d5b9743e66b52d15812"}, + {file = "grpcio-1.64.0-cp312-cp312-macosx_10_9_universal2.whl", hash = 
"sha256:650a8150a9b288f40d5b7c1d5400cc11724eae50bd1f501a66e1ea949173649b"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8de0399b983f8676a7ccfdd45e5b2caec74a7e3cc576c6b1eecf3b3680deda5e"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46b8b43ba6a2a8f3103f103f97996cad507bcfd72359af6516363c48793d5a7b"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a54362f03d4dcfae63be455d0a7d4c1403673498b92c6bfe22157d935b57c7a9"}, + {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1f8ea18b928e539046bb5f9c124d717fbf00cc4b2d960ae0b8468562846f5aa1"}, + {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c56c91bd2923ddb6e7ed28ebb66d15633b03e0df22206f22dfcdde08047e0a48"}, + {file = "grpcio-1.64.0-cp312-cp312-win32.whl", hash = "sha256:874c741c8a66f0834f653a69e7e64b4e67fcd4a8d40296919b93bab2ccc780ba"}, + {file = "grpcio-1.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:0da1d921f8e4bcee307aeef6c7095eb26e617c471f8cb1c454fd389c5c296d1e"}, + {file = "grpcio-1.64.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:c46fb6bfca17bfc49f011eb53416e61472fa96caa0979b4329176bdd38cbbf2a"}, + {file = "grpcio-1.64.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3d2004e85cf5213995d09408501f82c8534700d2babeb81dfdba2a3bff0bb396"}, + {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6d5541eb460d73a07418524fb64dcfe0adfbcd32e2dac0f8f90ce5b9dd6c046c"}, + {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f279ad72dd7d64412e10f2443f9f34872a938c67387863c4cd2fb837f53e7d2"}, + {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85fda90b81da25993aa47fae66cae747b921f8f6777550895fb62375b776a231"}, + {file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a053584079b793a54bece4a7d1d1b5c0645bdbee729215cd433703dc2532f72b"}, + {file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:579dd9fb11bc73f0de061cab5f8b2def21480fd99eb3743ed041ad6a1913ee2f"}, + {file = "grpcio-1.64.0-cp38-cp38-win32.whl", hash = "sha256:23b6887bb21d77649d022fa1859e05853fdc2e60682fd86c3db652a555a282e0"}, + {file = "grpcio-1.64.0-cp38-cp38-win_amd64.whl", hash = "sha256:753cb58683ba0c545306f4e17dabf468d29cb6f6b11832e1e432160bb3f8403c"}, + {file = "grpcio-1.64.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:2186d76a7e383e1466e0ea2b0febc343ffeae13928c63c6ec6826533c2d69590"}, + {file = "grpcio-1.64.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0f30596cdcbed3c98024fb4f1d91745146385b3f9fd10c9f2270cbfe2ed7ed91"}, + {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:d9171f025a196f5bcfec7e8e7ffb7c3535f7d60aecd3503f9e250296c7cfc150"}, + {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf4c8daed18ae2be2f1fc7d613a76ee2a2e28fdf2412d5c128be23144d28283d"}, + {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3550493ac1d23198d46dc9c9b24b411cef613798dc31160c7138568ec26bc9b4"}, + {file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3161a8f8bb38077a6470508c1a7301cd54301c53b8a34bb83e3c9764874ecabd"}, + {file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2e8fabe2cc57a369638ab1ad8e6043721014fdf9a13baa7c0e35995d3a4a7618"}, + {file = "grpcio-1.64.0-cp39-cp39-win32.whl", hash = 
"sha256:31890b24d47b62cc27da49a462efe3d02f3c120edb0e6c46dcc0025506acf004"}, + {file = "grpcio-1.64.0-cp39-cp39-win_amd64.whl", hash = "sha256:5a56797dea8c02e7d3a85dfea879f286175cf4d14fbd9ab3ef2477277b927baa"}, + {file = "grpcio-1.64.0.tar.gz", hash = "sha256:257baf07f53a571c215eebe9679c3058a313fd1d1f7c4eede5a8660108c52d9c"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.64.0)"] [[package]] name = "guardrails-api-client" -version = "0.1.1" +version = "0.2.1" description = "A client library for accessing Guardrails API" optional = false python-versions = "<4,>=3.8" files = [ - {file = "guardrails-api-client-0.1.1.tar.gz", hash = "sha256:d707661b0e63269c9ab257c4ba6eed20af18a734607cab8bb3bee0d2883bc50f"}, - {file = "guardrails_api_client-0.1.1-py3-none-any.whl", hash = "sha256:99baf4a11fcc61b420197c019fa24a479a44afb44b858e43d874cc95926295fd"}, + {file = "guardrails_api_client-0.2.1-py3-none-any.whl", hash = "sha256:e6f70304b498a79c621149ae433041bf298c2268acf622442f2356e0b371e903"}, + {file = "guardrails_api_client-0.2.1.tar.gz", hash = "sha256:68c5e31abffe227c2ec9e2f46e00ab94c7af5692e1c075331145280e0addde99"}, ] [package.dependencies] @@ -2230,7 +2219,7 @@ i18n = ["Babel (>=2.7)"] name = "joblib" version = "1.4.2" description = "Lightweight pipelining with Python functions" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, @@ -3808,7 +3797,7 @@ files = [ name = "nltk" version = "3.8.1" description = "Natural Language Toolkit" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"}, @@ -4142,91 +4131,85 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] [[package]] name = "opentelemetry-api" -version = "1.20.0" +version = "1.24.0" description = "OpenTelemetry Python API" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_api-1.20.0-py3-none-any.whl", hash = "sha256:982b76036fec0fdaf490ae3dfd9f28c81442a33414f737abc687a32758cdcba5"}, - {file = "opentelemetry_api-1.20.0.tar.gz", hash = "sha256:06abe351db7572f8afdd0fb889ce53f3c992dbf6f6262507b385cc1963e06983"}, + {file = "opentelemetry_api-1.24.0-py3-none-any.whl", hash = "sha256:0f2c363d98d10d1ce93330015ca7fd3a65f60be64e05e30f557c61de52c80ca2"}, + {file = "opentelemetry_api-1.24.0.tar.gz", hash = "sha256:42719f10ce7b5a9a73b10a4baf620574fb8ad495a9cbe5c18d76b75d8689c67e"}, ] [package.dependencies] deprecated = ">=1.2.6" -importlib-metadata = ">=6.0,<7.0" +importlib-metadata = ">=6.0,<=7.0" [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.20.0" +version = "1.24.0" description = "OpenTelemetry Protobuf encoding" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_common-1.20.0-py3-none-any.whl", hash = "sha256:dd63209b40702636ab6ae76a06b401b646ad7b008a906ecb41222d4af24fbdef"}, - {file = "opentelemetry_exporter_otlp_proto_common-1.20.0.tar.gz", hash = "sha256:df60c681bd61812e50b3a39a7a1afeeb6d4066117583249fcc262269374e7a49"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.24.0-py3-none-any.whl", hash = "sha256:e51f2c9735054d598ad2df5d3eca830fecfb5b0bda0a2fa742c9c7718e12f641"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.24.0.tar.gz", hash = 
"sha256:5d31fa1ff976cacc38be1ec4e3279a3f88435c75b38b1f7a099a1faffc302461"}, ] [package.dependencies] -backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} -opentelemetry-proto = "1.20.0" +opentelemetry-proto = "1.24.0" [[package]] name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.20.0" +version = "1.24.0" description = "OpenTelemetry Collector Protobuf over gRPC Exporter" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_grpc-1.20.0-py3-none-any.whl", hash = "sha256:7c3f066065891b56348ba2c7f9df6ec635a712841cae0a36f2f6a81642ae7dec"}, - {file = "opentelemetry_exporter_otlp_proto_grpc-1.20.0.tar.gz", hash = "sha256:6c06d43c3771bda1795226e327722b4b980fa1ca1ec9e985f2ef3e29795bdd52"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.24.0-py3-none-any.whl", hash = "sha256:f40d62aa30a0a43cc1657428e59fcf82ad5f7ea8fff75de0f9d9cb6f739e0a3b"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.24.0.tar.gz", hash = "sha256:217c6e30634f2c9797999ea9da29f7300479a94a610139b9df17433f915e7baa"}, ] [package.dependencies] -backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} deprecated = ">=1.2.6" googleapis-common-protos = ">=1.52,<2.0" grpcio = ">=1.0.0,<2.0.0" opentelemetry-api = ">=1.15,<2.0" -opentelemetry-exporter-otlp-proto-common = "1.20.0" -opentelemetry-proto = "1.20.0" -opentelemetry-sdk = ">=1.20.0,<1.21.0" +opentelemetry-exporter-otlp-proto-common = "1.24.0" +opentelemetry-proto = "1.24.0" +opentelemetry-sdk = ">=1.24.0,<1.25.0" [package.extras] test = ["pytest-grpc"] [[package]] name = "opentelemetry-exporter-otlp-proto-http" -version = "1.20.0" +version = "1.24.0" description = "OpenTelemetry Collector Protobuf over HTTP Exporter" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_http-1.20.0-py3-none-any.whl", hash = "sha256:03f6e768ad25f1c3a9586e8c695db4a4adf978f8546a1285fa962e16bfbb0bd6"}, - {file = "opentelemetry_exporter_otlp_proto_http-1.20.0.tar.gz", hash = "sha256:500f42821420fdf0759193d6438edc0f4e984a83e14c08a23023c06a188861b4"}, + {file = "opentelemetry_exporter_otlp_proto_http-1.24.0-py3-none-any.whl", hash = "sha256:25af10e46fdf4cd3833175e42f4879a1255fc01655fe14c876183a2903949836"}, + {file = "opentelemetry_exporter_otlp_proto_http-1.24.0.tar.gz", hash = "sha256:704c066cc96f5131881b75c0eac286cd73fc735c490b054838b4513254bd7850"}, ] [package.dependencies] -backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} deprecated = ">=1.2.6" googleapis-common-protos = ">=1.52,<2.0" opentelemetry-api = ">=1.15,<2.0" -opentelemetry-exporter-otlp-proto-common = "1.20.0" -opentelemetry-proto = "1.20.0" -opentelemetry-sdk = ">=1.20.0,<1.21.0" +opentelemetry-exporter-otlp-proto-common = "1.24.0" +opentelemetry-proto = "1.24.0" +opentelemetry-sdk = ">=1.24.0,<1.25.0" requests = ">=2.7,<3.0" -[package.extras] -test = ["responses (==0.22.0)"] - [[package]] name = "opentelemetry-proto" -version = "1.20.0" +version = "1.24.0" description = "OpenTelemetry Python Proto" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_proto-1.20.0-py3-none-any.whl", hash = "sha256:512c3d2c6864fb7547a69577c3907348e6c985b7a204533563cb4c4c5046203b"}, - {file = "opentelemetry_proto-1.20.0.tar.gz", hash = "sha256:cf01f49b3072ee57468bccb1a4f93bdb55411f4512d0ac3f97c5c04c0040b5a2"}, + {file = 
"opentelemetry_proto-1.24.0-py3-none-any.whl", hash = "sha256:bcb80e1e78a003040db71ccf83f2ad2019273d1e0828089d183b18a1476527ce"}, + {file = "opentelemetry_proto-1.24.0.tar.gz", hash = "sha256:ff551b8ad63c6cabb1845ce217a6709358dfaba0f75ea1fa21a61ceddc78cab8"}, ] [package.dependencies] @@ -4234,29 +4217,29 @@ protobuf = ">=3.19,<5.0" [[package]] name = "opentelemetry-sdk" -version = "1.20.0" +version = "1.24.0" description = "OpenTelemetry Python SDK" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_sdk-1.20.0-py3-none-any.whl", hash = "sha256:f2230c276ff4c63ea09b3cb2e2ac6b1265f90af64e8d16bbf275c81a9ce8e804"}, - {file = "opentelemetry_sdk-1.20.0.tar.gz", hash = "sha256:702e432a457fa717fd2ddfd30640180e69938f85bb7fec3e479f85f61c1843f8"}, + {file = "opentelemetry_sdk-1.24.0-py3-none-any.whl", hash = "sha256:fa731e24efe832e98bcd90902085b359dcfef7d9c9c00eb5b9a18587dae3eb59"}, + {file = "opentelemetry_sdk-1.24.0.tar.gz", hash = "sha256:75bc0563affffa827700e0f4f4a68e1e257db0df13372344aebc6f8a64cde2e5"}, ] [package.dependencies] -opentelemetry-api = "1.20.0" -opentelemetry-semantic-conventions = "0.41b0" +opentelemetry-api = "1.24.0" +opentelemetry-semantic-conventions = "0.45b0" typing-extensions = ">=3.7.4" [[package]] name = "opentelemetry-semantic-conventions" -version = "0.41b0" +version = "0.45b0" description = "OpenTelemetry Semantic Conventions" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "opentelemetry_semantic_conventions-0.41b0-py3-none-any.whl", hash = "sha256:45404391ed9e50998183a4925ad1b497c01c143f06500c3b9c3d0013492bb0f2"}, - {file = "opentelemetry_semantic_conventions-0.41b0.tar.gz", hash = "sha256:0ce5b040b8a3fc816ea5879a743b3d6fe5db61f6485e4def94c6ee4d402e1eb7"}, + {file = "opentelemetry_semantic_conventions-0.45b0-py3-none-any.whl", hash = "sha256:a4a6fb9a7bacd9167c082aa4681009e9acdbfa28ffb2387af50c2fef3d30c864"}, + {file = "opentelemetry_semantic_conventions-0.45b0.tar.gz", hash = "sha256:7c84215a44ac846bc4b8e32d5e78935c5c43482e491812a0bb8aaf87e4d92118"}, ] [[package]] @@ -8349,11 +8332,11 @@ manifest = ["manifest-ml"] pii = ["presidio_analyzer", "presidio_anonymizer"] profanity = ["alt-profanity-check"] sql = ["sqlalchemy", "sqlglot", "sqlvalidator"] -summary = ["nltk", "thefuzz"] +summary = ["thefuzz"] toxic-language = ["torch", "transformers"] vectordb = ["faiss-cpu", "numpy"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "fe6b6c42df209d16b013d3b8cd070399bb5804d4c10a28868a92d513c7fc520a" +content-hash = "f2c26b1ad671f41fb695448f5f007d9f29609581db2a23301169a4d0a5ab11f1" diff --git a/pyproject.toml b/pyproject.toml index 64d9bd868..07e80ad56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,12 +26,12 @@ rstr = "^3.2.2" typing-extensions = "^4.8.0" python-dateutil = "^2.8.2" tiktoken = ">=0.5.1" +nltk = "^3.8.1" sqlvalidator = {version = "^0.0.20", optional = true} sqlalchemy = {version = ">=2.0.9", optional = true} sqlglot = {version = "^19.0.3", optional = true} thefuzz = {version = "^0.20.0", optional = true} -nltk = {version = "^3.8.1", optional = true} faiss-cpu = {version = "^1.7.4", optional = true} numpy = {version = ">=1.24", optional = true} alt-profanity-check = {version = "^1.3.1", optional = true} @@ -50,19 +50,19 @@ huggingface_hub = {version = "^0.19.3", optional = true} pydash = "^7.0.6" docspec_python = "2.2.1" pydoc-markdown = "4.8.2" -opentelemetry-sdk = "1.20.0" -opentelemetry-exporter-otlp-proto-grpc = "1.20.0" 
-opentelemetry-exporter-otlp-proto-http = "1.20.0" langchain-core = ">=0.1,<0.3" coloredlogs = "^15.0.1" requests = "^2.31.0" -guardrails-api-client = "^0.1.1" +guardrails-api-client = "^0.2.1" jwt = "^1.3.1" pip = ">=22" +opentelemetry-sdk = "^1.24.0" +opentelemetry-exporter-otlp-proto-grpc = "^1.24.0" +opentelemetry-exporter-otlp-proto-http = "^1.24.0" [tool.poetry.extras] sql = ["sqlvalidator", "sqlalchemy", "sqlglot"] -summary = ["thefuzz", "nltk"] +summary = ["thefuzz"] vectordb = ["faiss-cpu", "numpy"] profanity = ["alt-profanity-check"] detect-secrets = ["detect-secrets"] diff --git a/tests/integration_tests/test_streaming.py b/tests/integration_tests/test_streaming.py index 4551f8f4e..149ddd62d 100644 --- a/tests/integration_tests/test_streaming.py +++ b/tests/integration_tests/test_streaming.py @@ -1,17 +1,27 @@ -# 2 tests +# 3 tests # 1. Test streaming with OpenAICallable (mock openai.Completion.create) # 2. Test streaming with OpenAIChatCallable (mock openai.ChatCompletion.create) -# Using the LowerCase Validator +# 3. Test string schema streaming +# Using the LowerCase Validator, and a custom validator to show new streaming behavior import json -from typing import Iterable +from typing import Any, Callable, Dict, Iterable, List, Optional, Union import openai import pytest from pydantic import BaseModel, Field import guardrails as gd +from guardrails.utils.casting_utils import to_int from guardrails.utils.openai_utils import OPENAI_VERSION -from guardrails.validator_base import OnFailAction +from guardrails.validator_base import ( + ErrorSpan, + FailResult, + OnFailAction, + PassResult, + ValidationResult, + Validator, + register_validator, +) from guardrails.validators import LowerCase expected_raw_output = {"statement": "I am DOING well, and I HOPE you aRe too."} @@ -20,6 +30,65 @@ expected_filter_refrain_output = {} +@register_validator(name="minsentencelength", data_type=["string", "list"]) +class MinSentenceLengthValidator(Validator): + def __init__( + self, + min: Optional[int] = None, + max: Optional[int] = None, + on_fail: Optional[Callable] = None, + ): + super().__init__( + on_fail=on_fail, + min=min, + max=max, + ) + self._min = to_int(min) + self._max = to_int(max) + + def sentence_split(self, value): + return list(map(lambda x: x + ".", value.split(".")[:-1])) + + def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult: + sentences = self.sentence_split(value) + error_spans = [] + index = 0 + for sentence in sentences: + if len(sentence) < self._min: + error_spans.append( + ErrorSpan( + start=index, + end=index + len(sentence), + reason=f"Sentence has length less than {self._min}. " + f"Please return a longer output, " + f"that is shorter than {self._max} characters.", + ) + ) + if len(sentence) > self._max: + error_spans.append( + ErrorSpan( + start=index, + end=index + len(sentence), + reason=f"Sentence has length greater than {self._max}. " + f"Please return a shorter output, " + f"that is shorter than {self._max} characters.", + ) + ) + index = index + len(sentence) + if len(error_spans) > 0: + return FailResult( + validated_chunk=value, + error_spans=error_spans, + error_message=f"Sentence has length less than {self._min}. 
" + f"Please return a longer output, " + f"that is shorter than {self._max} characters.", + ) + return PassResult(validated_chunk=value) + + def validate_stream(self, chunk: Any, metadata: Dict, **kwargs) -> ValidationResult: + return super().validate_stream(chunk, metadata, **kwargs) + + class Delta: content: str @@ -49,20 +118,18 @@ def __init__(self, choices, model): self.model = model -def mock_openai_completion_create(): +def mock_openai_completion_create(chunks): # Returns a generator - chunks = [ - '{"statement":', - ' "I am DOING', - " well, and I", - " HOPE you aRe", - ' too."}', - ] - def gen(): + index = 0 for chunk in chunks: + index = index + 1 + # finished = index == len(chunks) + # finish_reason = "stop" if finished else None + # print("FINISH REASON", finish_reason) if OPENAI_VERSION.startswith("0"): yield { + # TODO: for some reason using finish_reason here breaks everything "choices": [{"text": chunk, "finish_reason": None}], "model": "OpenAI model name", } @@ -72,6 +139,7 @@ def gen(): Choice( text=chunk, delta=Delta(content=""), + # TODO: for some reason using finish_reason here breaks everything # noqa finish_reason=None, ) ], @@ -81,24 +149,22 @@ def gen(): return gen() -def mock_openai_chat_completion_create(): +def mock_openai_chat_completion_create(chunks): # Returns a generator - chunks = [ - '{"statement":', - ' "I am DOING', - " well, and I", - " HOPE you aRe", - ' too."}', - ] - def gen(): + index = 0 for chunk in chunks: + index = index + 1 + # finished = index == len(chunks) + # finish_reason = "stop" if finished else None + # print("FINISH REASON", finish_reason) if OPENAI_VERSION.startswith("0"): yield { "choices": [ { "index": 0, "delta": {"content": chunk}, + # TODO: for some reason using finish_reason here breaks everything # noqa "finish_reason": None, } ] @@ -109,6 +175,7 @@ def gen(): Choice( text="", delta=Delta(content=chunk), + # TODO: for some reason using finish_reason here breaks everything # noqa finish_reason=None, ) ], @@ -146,25 +213,57 @@ class LowerCaseRefrain(BaseModel): ) +expected_minsentence_noop_output = "" + + +class MinSentenceLengthNoOp(BaseModel): + statement: str = Field( + description="Validates whether the text is in lower case.", + validators=[MinSentenceLengthValidator(on_fail=OnFailAction.NOOP)], + ) + + +STR_PROMPT = "Say something nice to me." + PROMPT = """ Say something nice to me. ${gr.complete_json_suffix} """ +JSON_LLM_CHUNKS = [ + '{"statement":', + ' "I am DOING', + " well, and I", + " HOPE you aRe", + ' too."}', +] + @pytest.mark.parametrize( - "op_class, expected_validated_output", + "guard, expected_validated_output", [ - (LowerCaseNoop, expected_noop_output), - (LowerCaseFix, expected_fix_output), - (LowerCaseFilter, expected_filter_refrain_output), - (LowerCaseRefrain, expected_filter_refrain_output), + ( + gd.Guard.from_pydantic(output_class=LowerCaseNoop, prompt=PROMPT), + expected_noop_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseFix, prompt=PROMPT), + expected_fix_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseFilter, prompt=PROMPT), + expected_filter_refrain_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseRefrain, prompt=PROMPT), + expected_filter_refrain_output, + ), ], ) def test_streaming_with_openai_callable( mocker, - op_class, + guard, expected_validated_output, ): """Test streaming with OpenAICallable. 
@@ -173,17 +272,15 @@ def test_streaming_with_openai_callable( """ if OPENAI_VERSION.startswith("0"): mocker.patch( - "openai.Completion.create", return_value=mock_openai_completion_create() + "openai.Completion.create", + return_value=mock_openai_completion_create(JSON_LLM_CHUNKS), ) else: mocker.patch( "openai.resources.Completions.create", - return_value=mock_openai_completion_create(), + return_value=mock_openai_completion_create(JSON_LLM_CHUNKS), ) - # Create a guard object - guard = gd.Guard.from_pydantic(output_class=op_class, prompt=PROMPT) - method = ( openai.Completion.create if OPENAI_VERSION.startswith("0") @@ -210,17 +307,29 @@ def test_streaming_with_openai_callable( @pytest.mark.parametrize( - "op_class, expected_validated_output", + "guard, expected_validated_output", [ - (LowerCaseNoop, expected_noop_output), - (LowerCaseFix, expected_fix_output), - (LowerCaseFilter, expected_filter_refrain_output), - (LowerCaseRefrain, expected_filter_refrain_output), + ( + gd.Guard.from_pydantic(output_class=LowerCaseNoop, prompt=PROMPT), + expected_noop_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseFix, prompt=PROMPT), + expected_fix_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseFilter, prompt=PROMPT), + expected_filter_refrain_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseRefrain, prompt=PROMPT), + expected_filter_refrain_output, + ), ], ) def test_streaming_with_openai_chat_callable( mocker, - op_class, + guard, expected_validated_output, ): """Test streaming with OpenAIChatCallable. @@ -230,17 +339,14 @@ def test_streaming_with_openai_chat_callable( if OPENAI_VERSION.startswith("0"): mocker.patch( "openai.ChatCompletion.create", - return_value=mock_openai_chat_completion_create(), + return_value=mock_openai_chat_completion_create(JSON_LLM_CHUNKS), ) else: mocker.patch( "openai.resources.chat.completions.Completions.create", - return_value=mock_openai_chat_completion_create(), + return_value=mock_openai_chat_completion_create(JSON_LLM_CHUNKS), ) - # Create a guard object - guard = gd.Guard.from_pydantic(output_class=op_class, prompt=PROMPT) - method = ( openai.ChatCompletion.create if OPENAI_VERSION.startswith("0") @@ -265,3 +371,99 @@ def test_streaming_with_openai_chat_callable( assert actual_output.raw_llm_output == json.dumps(expected_raw_output) assert actual_output.validated_output == expected_validated_output + + +STR_LLM_CHUNKS = [ + # 38 characters + "This sentence is simply just ", + "too long." + # 25 characters long + "This ", + "sentence ", + "is 2 ", + "short." + # 29 characters long + "This sentence is just ", + "right.", +] + + +@pytest.mark.parametrize( + "guard, expected_error_spans", + [ + ( + gd.Guard.from_string( + # only the middle sentence should pass + validators=[ + MinSentenceLengthValidator(26, 30, on_fail=OnFailAction.NOOP) + ], + prompt=STR_PROMPT, + ), + # each value is a tuple + # first is expected text inside span + # second is the reason for failure + [ + [ + "This sentence is simply just too long.", + ( + "Sentence has length greater than 30. " + "Please return a shorter output, " + "that is shorter than 30 characters." + ), + ], + [ + "This sentence is 2 short.", + ( + "Sentence has length less than 26. " + "Please return a longer output, " + "that is shorter than 30 characters." + ), + ], + ], + ) + ], +) +def test_string_schema_streaming_with_openai_chat(mocker, guard, expected_error_spans): + """Test string schema streaming with OpenAIChatCallable. + + Mocks openai.ChatCompletion.create. 
+ """ + if OPENAI_VERSION.startswith("0"): + mocker.patch( + "openai.ChatCompletion.create", + return_value=mock_openai_chat_completion_create(STR_LLM_CHUNKS), + ) + else: + mocker.patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_openai_chat_completion_create(STR_LLM_CHUNKS), + ) + + method = ( + openai.ChatCompletion.create + if OPENAI_VERSION.startswith("0") + else openai.chat.completions.create + ) + + method.__name__ = "mock_openai_chat_completion_create" + generator = guard( + method, + model="gpt-3.5-turbo", + max_tokens=10, + temperature=0, + stream=True, + ) + + assert isinstance(generator, Iterable) + + accumulated_output = "" + for op in generator: + accumulated_output += op.raw_llm_output + error_spans = guard.error_spans_in_output() + + # print spans + assert len(error_spans) == len(expected_error_spans) + for error_span, expected in zip(error_spans, expected_error_spans): + assert accumulated_output[error_span.start : error_span.end] == expected[0] + assert error_span.reason == expected[1] + # TODO assert something about these error spans diff --git a/tests/unit_tests/cli/hub/test_install.py b/tests/unit_tests/cli/hub/test_install.py index 89865567c..a5bf4bf8f 100644 --- a/tests/unit_tests/cli/hub/test_install.py +++ b/tests/unit_tests/cli/hub/test_install.py @@ -88,6 +88,7 @@ def test_happy_path(self, mocker): class TestPipProcess: def test_no_package_string_format(self, mocker): + mocker.patch("guardrails.cli.hub.install.os.environ", return_value={}) mock_logger_debug = mocker.patch("guardrails.cli.hub.utils.logger.debug") mock_sys_executable = mocker.patch("guardrails.cli.hub.install.sys.executable") @@ -109,12 +110,14 @@ def test_no_package_string_format(self, mocker): mock_logger_debug.assert_has_calls(debug_calls) mock_subprocess_check_output.assert_called_once_with( - [mock_sys_executable, "-m", "pip", "inspect", "--path=./install-here"] + [mock_sys_executable, "-m", "pip", "inspect", "--path=./install-here"], + env={}, ) assert response == "string output" def test_json_format(self, mocker): + mocker.patch("guardrails.cli.hub.install.os.environ", return_value={}) mock_logger_debug = mocker.patch("guardrails.cli.hub.install.logger.debug") mock_sys_executable = mocker.patch("guardrails.cli.hub.install.sys.executable") @@ -147,7 +150,7 @@ def parsebytes(self, *args): mock_logger_debug.assert_has_calls(debug_calls) mock_subprocess_check_output.assert_called_once_with( - [mock_sys_executable, "-m", "pip", "show", "pip"] + [mock_sys_executable, "-m", "pip", "show", "pip"], env={} ) assert response == {"output": "json"} @@ -708,6 +711,7 @@ def test_install_hub_module(mocker): flags=["--path=mock/install/directory"], format="json", quiet=False, + no_color=True, ), call("install", "rstr", quiet=False), call("install", "openai<2", quiet=False), @@ -783,6 +787,7 @@ def test_quiet_install(mocker): flags=["--path=mock/install/directory"], format="json", quiet=True, + no_color=True, ), call("install", "rstr", quiet=True), call("install", "openai<2", quiet=True), diff --git a/tests/unit_tests/test_async_validator_service.py b/tests/unit_tests/test_async_validator_service.py index 34431fe14..8e314c095 100644 --- a/tests/unit_tests/test_async_validator_service.py +++ b/tests/unit_tests/test_async_validator_service.py @@ -91,7 +91,7 @@ async def test_async_validate_with_children(mocker): assert run_validators_mock.call_count == 1 run_validators_mock.assert_called_once_with( - iteration, field_validation, True, {}, "$.mock-parent-key" + iteration, 
field_validation, True, {}, "$.mock-parent-key", stream=False ) assert validated_value == "run_validators_mock" @@ -118,7 +118,7 @@ async def test_async_validate_without_children(mocker): assert run_validators_mock.call_count == 1 run_validators_mock.assert_called_once_with( - iteration, empty_field_validation, True, {}, "$.mock-key" + iteration, empty_field_validation, True, {}, "$.mock-key", stream=False ) assert validated_value == "run_validators_mock" @@ -186,7 +186,9 @@ async def test_run_validators(mocker): (OnFailAction.NOOP, [noop_validator_1, noop_validator_2]), ] - def mock_run_validator(iteration, validator, value, metadata, property_path): + def mock_run_validator( + iteration, validator, value, metadata, property_path, stream + ): return ValidatorLogs( registered_name=validator.name, validator_name=validator.name, @@ -234,6 +236,7 @@ async def mock_gather(*args): empty_field_validation.value, {}, "$", + False, ) assert run_validator_mock.call_count == 3 diff --git a/tests/unit_tests/test_validators.py b/tests/unit_tests/test_validators.py index 3f6f2c8c9..e8cf1ea73 100644 --- a/tests/unit_tests/test_validators.py +++ b/tests/unit_tests/test_validators.py @@ -899,11 +899,11 @@ async def mock_llm_api(*args, **kwargs): [ ( OnFailAction.REASK, - "Prompt validation failed: incorrect_value='What kind of pet should I get?\\n\\nJson Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Instructions validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Message history validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Prompt validation failed: incorrect_value='\\nThis is not two words\\n\\n\\nString Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='This is')] path=None", # noqa - "Instructions validation failed: incorrect_value='\\nThis also is not two words\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='This also')] path=None", # noqa + "Prompt validation failed: incorrect_value='What kind of pet should I get?\\n\\nJson Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Instructions validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Message history validation failed: incorrect_value='What kind of pet should I get?' 
fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Prompt validation failed: incorrect_value='\\nThis is not two words\\n\\n\\nString Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='This is', error_spans=None)] path=None", # noqa + "Instructions validation failed: incorrect_value='\\nThis also is not two words\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='This also', error_spans=None)] path=None", # noqa ), ( OnFailAction.FILTER, @@ -1044,11 +1044,11 @@ def test_input_validation_fail( [ ( OnFailAction.REASK, - "Prompt validation failed: incorrect_value='What kind of pet should I get?\\n\\nJson Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Instructions validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Message history validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Prompt validation failed: incorrect_value='\\nThis is not two words\\n\\n\\nString Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='This is')] path=None", # noqa - "Instructions validation failed: incorrect_value='\\nThis also is not two words\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='This also')] path=None", # noqa + "Prompt validation failed: incorrect_value='What kind of pet should I get?\\n\\nJson Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Instructions validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Message history validation failed: incorrect_value='What kind of pet should I get?' 
fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Prompt validation failed: incorrect_value='\\nThis is not two words\\n\\n\\nString Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='This is', error_spans=None)] path=None", # noqa + "Instructions validation failed: incorrect_value='\\nThis also is not two words\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='This also', error_spans=None)] path=None", # noqa ), ( OnFailAction.FILTER, diff --git a/tests/unit_tests/utils/test_api_utils.py b/tests/unit_tests/utils/test_api_utils.py new file mode 100644 index 000000000..4a74cea67 --- /dev/null +++ b/tests/unit_tests/utils/test_api_utils.py @@ -0,0 +1,19 @@ +from guardrails.utils.api_utils import extract_serializeable_metadata + + +def test_extract_serializeable_metadata(): + def baz(): + print("baz") + + class NonMeta: + data = "data" + + metadata = { + "foo": "bar", + "baz": baz, + "non_meta": NonMeta(), + } + + extracted_metadata = extract_serializeable_metadata(metadata) + + assert extracted_metadata == {"foo": "bar"}
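
As a quick illustration of the helper exercised by this new test, here is a minimal usage sketch. It assumes the guardrails package from this branch is installed; the dictionary keys and the FaissIndexHandle class are placeholder names invented for the example, and the expected output is inferred only from the assertions in the test above (plain values kept, callables and arbitrary objects dropped).

from guardrails.utils.api_utils import extract_serializeable_metadata


class FaissIndexHandle:
    # Hypothetical non-serializable object, analogous to NonMeta in the test above.
    data = "data"


metadata = {
    "user_id": "abc-123",         # plain string: expected to be kept
    "on_token": print,            # callable: expected to be dropped
    "index": FaissIndexHandle(),  # custom object instance: expected to be dropped
}

print(extract_serializeable_metadata(metadata))
# Based on the test above, this should print: {'user_id': 'abc-123'}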