diff --git a/guardrails/classes/history/call.py b/guardrails/classes/history/call.py index 287d344cc..503cefb9a 100644 --- a/guardrails/classes/history/call.py +++ b/guardrails/classes/history/call.py @@ -99,8 +99,8 @@ def reask_prompts(self) -> Stack[Optional[str]]: @property def instructions(self) -> Optional[str]: - """The instructions as provided by the user when initializing or calling - the Guard.""" + """The instructions as provided by the user when initializing or + calling the Guard.""" return self.inputs.instructions @property diff --git a/guardrails/classes/history/iteration.py b/guardrails/classes/history/iteration.py index 1eb9e12dd..ef913bafd 100644 --- a/guardrails/classes/history/iteration.py +++ b/guardrails/classes/history/iteration.py @@ -15,6 +15,7 @@ from guardrails.utils.logs_utils import ValidatorLogs from guardrails.utils.pydantic_utils import ArbitraryModel from guardrails.utils.reask_utils import ReAsk +from guardrails.validator_base import ErrorSpan class Iteration(ArbitraryModel): @@ -155,6 +156,14 @@ def failed_validations(self) -> List[ValidatorLogs]: iteration.""" return self.outputs.failed_validations + @property + def error_spans_in_output(self) -> List[ErrorSpan]: + """The error spans from the LLM response. + + These indices are relative to the complete LLM output. + """ + return self.outputs.error_spans_in_output + @property def status(self) -> str: """Representation of the end state of this iteration. diff --git a/guardrails/classes/history/outputs.py b/guardrails/classes/history/outputs.py index 9d4d19544..953ed6f4b 100644 --- a/guardrails/classes/history/outputs.py +++ b/guardrails/classes/history/outputs.py @@ -1,5 +1,4 @@ from typing import Dict, List, Optional, Sequence, Union - from pydantic import Field from typing_extensions import deprecated @@ -8,7 +7,7 @@ from guardrails.utils.logs_utils import ValidatorLogs from guardrails.utils.pydantic_utils import ArbitraryModel from guardrails.utils.reask_utils import ReAsk -from guardrails.validator_base import FailResult +from guardrails.validator_base import ErrorSpan, FailResult class Outputs(ArbitraryModel): @@ -75,6 +74,30 @@ def failed_validations(self) -> List[ValidatorLogs]: ] ) + @property + def error_spans_in_output(self) -> List[ErrorSpan]: + """The error spans from the LLM response. + + These indices are relative to the complete LLM output. + """ + total_len = 0 + spans_in_output = [] + for log in self.validator_logs: + result = log.validation_result + if isinstance(result, FailResult): + if result.error_spans is not None: + for error_span in result.error_spans: + spans_in_output.append( + ErrorSpan( + start=error_span.start + total_len, + end=error_span.end + total_len, + reason=error_span.reason, + ) + ) + if result and result.validated_chunk is not None: + total_len += len(result.validated_chunk) + return spans_in_output + @property def status(self) -> str: """Representation of the end state of the validation run. diff --git a/guardrails/guard.py b/guardrails/guard.py index 8940be790..948514eda 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1144,6 +1144,15 @@ async def _async_parse( return ValidationOutcome[OT].from_guard_history(call) + def error_spans_in_output(self): + try: + call = self.history[0] + iter = call.iterations[0] + llm_spans = iter.error_spans_in_output + return llm_spans + except (AttributeError, TypeError): + return [] + @deprecated( """The `with_prompt_validation` method is deprecated, and will be removed in 0.5.x. 
Instead, please use diff --git a/guardrails/run/runner.py b/guardrails/run/runner.py index 9688d43ea..eba815973 100644 --- a/guardrails/run/runner.py +++ b/guardrails/run/runner.py @@ -547,17 +547,29 @@ def validate( index: int, parsed_output: Any, output_schema: Schema, + stream: Optional[bool] = False, **kwargs, ): """Validate the output.""" - validated_output = output_schema.validate( - iteration, - parsed_output, - self.metadata, - attempt_number=index, - disable_tracer=self._disable_tracer, - **kwargs, - ) + if isinstance(output_schema, StringSchema): + validated_output = output_schema.validate( + iteration, + parsed_output, + self.metadata, + index, + self._disable_tracer, + stream, + **kwargs, + ) + else: + validated_output = output_schema.validate( + iteration, + parsed_output, + self.metadata, + attempt_number=index, + disable_tracer=self._disable_tracer, + **kwargs, + ) return validated_output diff --git a/guardrails/run/stream_runner.py b/guardrails/run/stream_runner.py index 890c3fd9c..5dc842eda 100644 --- a/guardrails/run/stream_runner.py +++ b/guardrails/run/stream_runner.py @@ -15,6 +15,7 @@ from guardrails.schema import Schema, StringSchema from guardrails.utils.openai_utils import OPENAI_VERSION from guardrails.utils.reask_utils import SkeletonReAsk +from guardrails.constants import pass_status class StreamRunner(Runner): @@ -153,47 +154,118 @@ def step( verified = set() # Loop over the stream # and construct "fragments" of concatenated chunks - for chunk in stream: - # 1. Get the text from the chunk and append to fragment - chunk_text = self.get_chunk_text(chunk, api) - fragment += chunk_text + # for now, handle string and json schema differently - # 2. Parse the fragment - parsed_fragment, move_to_next = self.parse( - index, fragment, output_schema, verified - ) - if move_to_next: - # Continue to next chunk - continue + if isinstance(output_schema, StringSchema): + stream_finished = False + last_chunk_text = "" + for chunk in stream: + # 1. Get the text from the chunk and append to fragment + chunk_text = self.get_chunk_text(chunk, api) + last_chunk_text = chunk_text + finished = self.is_last_chunk(chunk, api) + if finished: + stream_finished = True + fragment += chunk_text - # 3. Run output validation - validated_fragment = self.validate( - iteration, - index, - parsed_fragment, - output_schema, - validate_subschema=True, - ) - if isinstance(validated_fragment, SkeletonReAsk): - raise ValueError( - "Received fragment schema is an invalid sub-schema " - "of the expected output JSON schema." + # 2. Parse the chunk + parsed_chunk, move_to_next = self.parse( + index, chunk_text, output_schema, verified + ) + if move_to_next: + # Continue to next chunk + continue + validated_text = self.validate( + iteration, + index, + parsed_chunk, + output_schema, + True, + validate_subschema=True, + # if it is the last chunk, validate everything that's left + remainder=finished, ) + if isinstance(validated_text, SkeletonReAsk): + raise ValueError( + "Received fragment schema is an invalid sub-schema " + "of the expected output JSON schema." + ) - # 4. Introspect: inspect the validated fragment for reasks - reasks, valid_op = self.introspect(index, validated_fragment, output_schema) - if reasks: - raise ValueError( - "Reasks are not yet supported with streaming. Please " - "remove reasks from schema or disable streaming." + # 4. 
Introspect: inspect the validated fragment for reasks + reasks, valid_op = self.introspect(index, validated_text, output_schema) + if reasks: + raise ValueError( + "Reasks are not yet supported with streaming. Please " + "remove reasks from schema or disable streaming." + ) + # 5. Convert validated fragment to a pretty JSON string + passed = call_log.status == pass_status + yield ValidationOutcome( + # The chunk or the whole output? + raw_llm_output=chunk_text, + validated_output=validated_text, + validation_passed=passed, + ) + # handle case where generator doesn't give finished status + if not stream_finished: + last_result = self.validate( + iteration, + index, + "", + output_schema, + True, + validate_subschema=True, + remainder=True, ) + if len(last_result) > 0: + passed = call_log.status == pass_status + yield ValidationOutcome( + raw_llm_output=last_chunk_text, + validated_output=last_result, + validation_passed=passed, + ) + # handle non string schema + else: + for chunk in stream: + # 1. Get the text from the chunk and append to fragment + chunk_text = self.get_chunk_text(chunk, api) + fragment += chunk_text - # 5. Convert validated fragment to a pretty JSON string - yield ValidationOutcome( - raw_llm_output=fragment, - validated_output=validated_fragment, - validation_passed=validated_fragment is not None, - ) + parsed_fragment, move_to_next = self.parse( + index, fragment, output_schema, verified + ) + if move_to_next: + # Continue to next chunk + continue + validated_fragment = self.validate( + iteration, + index, + parsed_fragment, + output_schema, + validate_subschema=True, + ) + if isinstance(validated_fragment, SkeletonReAsk): + raise ValueError( + "Received fragment schema is an invalid sub-schema " + "of the expected output JSON schema." + ) + + # 4. Introspect: inspect the validated fragment for reasks + reasks, valid_op = self.introspect( + index, validated_fragment, output_schema + ) + if reasks: + raise ValueError( + "Reasks are not yet supported with streaming. Please " + "remove reasks from schema or disable streaming." + ) + + # 5. 
Convert validated fragment to a pretty JSON string + yield ValidationOutcome( + raw_llm_output=fragment, + validated_output=validated_fragment, + validation_passed=validated_fragment is not None, + ) # Finally, add to logs iteration.outputs.raw_output = fragment @@ -201,6 +273,32 @@ def step( iteration.outputs.validation_response = validated_fragment iteration.outputs.guarded_output = valid_op + def is_last_chunk(self, chunk: Any, api: Union[PromptCallableBase, None]) -> bool: + """Detect if chunk is final chunk.""" + if isinstance(api, OpenAICallable): + if OPENAI_VERSION.startswith("0"): + finished = chunk["choices"][0]["finish_reason"] + return finished is not None + else: + finished = chunk.choices[0].finish_reason + return finished is not None + elif isinstance(api, OpenAIChatCallable): + if OPENAI_VERSION.startswith("0"): + finished = chunk["choices"][0]["finish_reason"] + return finished is not None + else: + finished = chunk.choices[0].finish_reason + return finished is not None + elif isinstance(api, LiteLLMCallable): + finished = chunk.choices[0].finish_reason + return finished is not None + else: + try: + finished = chunk.choices[0].finish_reason + return finished is not None + except (AttributeError, TypeError): + return False + def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str: """Get the text from a chunk.""" chunk_text = "" diff --git a/guardrails/schema/string_schema.py b/guardrails/schema/string_schema.py index cbff5d5d1..200666041 100644 --- a/guardrails/schema/string_schema.py +++ b/guardrails/schema/string_schema.py @@ -134,6 +134,7 @@ def validate( metadata: Dict, attempt_number: int = 0, disable_tracer: Optional[bool] = True, + stream: Optional[bool] = False, **kwargs, ) -> Any: """Validate a dictionary of data against the schema. @@ -160,19 +161,20 @@ def validate( dummy_key: data, }, ) - validated_response, metadata = validator_service.validate( value=data, metadata=metadata, validator_setup=validation, iteration=iteration, disable_tracer=disable_tracer, + stream=stream, + **kwargs, ) validated_response = {dummy_key: validated_response} if check_refrain_in_dict(validated_response): - # If the data contains a `Refain` value, we return an empty + # If the data contains a `Refrain` value, we return an empty # dictionary. logger.debug("Refrain detected.") validated_response = {} diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index 2b92915d9..6fb1cc9df 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -1,4 +1,5 @@ import inspect +import nltk from collections import defaultdict from copy import deepcopy from enum import Enum @@ -175,6 +176,29 @@ class Refrain: pass +# functions to get chunks + + +def split_sentence_str(chunk: str): + if "." not in chunk: + return [] + fragments = chunk.split(".") + return [fragments[0] + ".", ".".join(fragments[1:])] + + +def split_sentence_nltk(chunk: str): + # using the sentence tokenizer is expensive + # we check for a . to avoid wastefully calling the tokenizer + if "." not in chunk: + return [] + sentences = nltk.sent_tokenize(chunk) + if len(sentences) == 0: + return [] + # return the sentence + # then the remaining chunks that aren't finished accumulating + return [sentences[0], "".join(sentences[1:])] + + def check_refrain_in_list(schema: List) -> bool: """Checks if a Refrain object exists in a list. 
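The two sentence-splitting helpers above implement the chunking contract used throughout this patch: they return an empty list while a sentence is still accumulating, or a two-element list of [complete chunk to validate, unfinished remainder]. A minimal sketch of that contract (the function body is copied from the hunk above with added comments; the example inputs are illustrative only):

```python
def split_sentence_str(chunk: str):
    # no sentence boundary yet -> signal "keep accumulating"
    if "." not in chunk:
        return []
    fragments = chunk.split(".")
    # first complete sentence, then everything still accumulating
    return [fragments[0] + ".", ".".join(fragments[1:])]


# nothing to validate yet: no "." has arrived
assert split_sentence_str("This is still accumul") == []

# one finished sentence plus the start of the next one
assert split_sentence_str("This is done. This is n") == [
    "This is done.",
    " This is n",
]
```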
@@ -355,6 +379,9 @@ def get_validator(name: str): class ValidationResult(BaseModel): outcome: str metadata: Optional[Dict[str, Any]] = None + # value argument passed to validator.validate + # or validator.validate_stream + validated_chunk: Optional[Any] = None class PassResult(ValidationResult): @@ -367,11 +394,21 @@ class ValueOverrideSentinel: value_override: Optional[Any] = Field(default=ValueOverrideSentinel) +# specifies the start and end of a segment of the validated chunk +class ErrorSpan(BaseModel): + start: int + end: int + # reason validation failed, specific to this chunk + reason: str + + class FailResult(ValidationResult): outcome: Literal["fail"] = "fail" error_message: str fix_value: Optional[Any] = None + # segments that caused validation to fail + error_spans: Optional[List[ErrorSpan]] = None class OnFailAction(str, Enum): @@ -390,6 +427,10 @@ class Validator(Runnable): rail_alias: str = "" + # buffer of stream chunks awaiting validation; the chunking function + # returns an empty list (keep accumulating) or two chunks: the chunk + # to validate and the incomplete remainder + accumulated_chunks = [] run_in_separate_process = False override_value_on_pass = False required_metadata_keys = [] @@ -448,10 +489,49 @@ def __init__( self.rail_alias in validators_registry ), f"Validator {self.__class__.__name__} is not registered. " + def chunking_function(self, chunk: str): + return split_sentence_str(chunk) + def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult: """Validates a value and return a validation result.""" raise NotImplementedError + def validate_stream( + self, chunk: Any, metadata: Dict[str, Any], **kwargs + ) -> Optional[ValidationResult]: + """Validates a chunk emitted by an LLM. If the LLM chunk is smaller + than the validator's chunking strategy, it will be accumulated until it + reaches the desired size. In the meantime, the validator will return + None. + + If the accumulated text contains more than one complete chunk, + only the first complete chunk is validated on this call; the + remainder is held back for subsequent calls. + + Otherwise, the validator will validate the chunk and return the + result. + """ + # combine accumulated chunks and the new chunk + self.accumulated_chunks.append(chunk) + accumulated_text = "".join(self.accumulated_chunks) + # check if enough chunks have accumulated for validation + splitcontents = self.chunking_function(accumulated_text) + + # if the remainder kwarg is passed, validate the remainder regardless + remainder = kwargs.get("remainder", False) + if remainder: + splitcontents = [accumulated_text, ""] + if len(splitcontents) == 0: + return None + [chunk_to_validate, new_accumulated_chunks] = splitcontents + # hold back the remainder, because it may not be a complete chunk yet + self.accumulated_chunks = [new_accumulated_chunks] + validation_result = self.validate(chunk_to_validate, metadata) + # if validate() doesn't set the validated chunk, set it here + if validation_result.validated_chunk is None: + validation_result.validated_chunk = chunk_to_validate + return validation_result + def to_prompt(self, with_keywords: bool = True) -> str: """Convert the validator to a prompt.
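For reference, a minimal sketch of how `validate_stream` behaves under the default sentence chunking above. The `AlwaysPass` validator and the chunk strings are illustrative only; `register_validator`, `PassResult`, and `Validator` are the same names this patch imports elsewhere:

```python
from typing import Any, Dict

from guardrails.validator_base import (
    PassResult,
    ValidationResult,
    Validator,
    register_validator,
)


@register_validator(name="always-pass", data_type=["string"])
class AlwaysPass(Validator):
    """Toy validator: accepts every complete sentence it is handed."""

    def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult:
        return PassResult()


v = AlwaysPass()
# partial sentence: buffered, nothing validated yet
assert v.validate_stream("This is a ", {}) is None
# a sentence completes: the buffered text is validated as one chunk
result = v.validate_stream("sentence. And more", {})
assert result.validated_chunk == "This is a sentence."
# remainder=True flushes whatever is still buffered
tail = v.validate_stream("", {}, remainder=True)
assert tail.validated_chunk == " And more"
```

Note that `accumulated_chunks` is declared at class level in this patch, so the buffer is shared across validator instances until the first complete chunk rebinds it on the instance; initializing it in `__init__` would be the safer pattern.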
diff --git a/guardrails/validator_service.py b/guardrails/validator_service.py index 14f610846..fa5e60c2e 100644 --- a/guardrails/validator_service.py +++ b/guardrails/validator_service.py @@ -41,8 +41,14 @@ def __init__(self, disable_tracer: Optional[bool] = True): # Using `fork` instead of `spawn` may alleviate the symptom for POSIX systems, # but is relatively unsupported on Windows. def execute_validator( - self, validator: Validator, value: Any, metadata: Optional[Dict] - ) -> ValidationResult: + self, + validator: Validator, + value: Any, + metadata: Optional[Dict], + stream: Optional[bool] = False, + **kwargs, + ) -> Optional[ValidationResult]: + validate_func = validator.validate_stream if stream else validator.validate traced_validator = trace_validator( validator_name=validator.rail_alias, obj_id=id(validator), @@ -50,8 +56,8 @@ def execute_validator( # namespace=validator.namespace, on_fail_descriptor=validator.on_fail_descriptor, **validator._kwargs, - )(validator.validate) - result = traced_validator(value, metadata) + )(validate_func) + result = traced_validator(value, metadata, **kwargs) return result def perform_correction( @@ -112,6 +118,8 @@ def run_validator( value: Any, metadata: Dict, property_path: str, + stream: Optional[bool] = False, + **kwargs, ) -> ValidatorLogs: validator_class_name = validator.__class__.__name__ validator_logs = ValidatorLogs( @@ -123,7 +131,7 @@ def run_validator( iteration.outputs.validator_logs.append(validator_logs) start_time = datetime.now() - result = self.execute_validator(validator, value, metadata) + result = self.execute_validator(validator, value, metadata, stream, **kwargs) end_time = datetime.now() if result is None: result = PassResult() @@ -161,13 +169,14 @@ def run_validators( value: Any, metadata: Dict[str, Any], property_path: str, + stream: Optional[bool] = False, + **kwargs, ) -> Tuple[Any, Dict[str, Any]]: # Validate the field for validator in validator_setup.validators: validator_logs = self.run_validator( - iteration, validator, value, metadata, property_path + iteration, validator, value, metadata, property_path, stream, **kwargs ) - result = validator_logs.validation_result if isinstance(result, FailResult): value = self.perform_correction( @@ -179,11 +188,11 @@ def run_validators( and result.value_override is not result.ValueOverrideSentinel ): value = result.value_override - else: + elif not stream: raise RuntimeError(f"Unexpected result type {type(result)}") validator_logs.value_after_validation = value - if result.metadata is not None: + if result and result.metadata is not None: metadata = result.metadata if isinstance(value, (Refrain, Filter, ReAsk)): @@ -228,6 +237,29 @@ def validate( value, metadata = self.run_validators( iteration, validator_setup, value, metadata, property_path ) + return value, metadata + + def validate_stream( + self, + value: Any, + metadata: dict, + validator_setup: FieldValidation, + iteration: Iteration, + path: str = "$", + **kwargs, + ) -> Tuple[Any, dict]: + property_path = ( + f"{path}.{validator_setup.key}" + if key_not_empty(validator_setup.key) + else path + ) + # I assume validate stream doesn't need validate_dependents + # because right now we're only handling StringSchema + + # Validate the field + value, metadata = self.run_validators( + iteration, validator_setup, value, metadata, property_path, True, **kwargs + ) return value, metadata @@ -411,9 +443,15 @@ def validate( validator_setup: FieldValidation, iteration: Iteration, disable_tracer: Optional[bool] = True, + stream: 
Optional[bool] = False, + **kwargs, ): process_count = int(os.environ.get("GUARDRAILS_PROCESS_COUNT", 10)) - + if stream: + sequential_validator_service = SequentialValidatorService(disable_tracer) + return sequential_validator_service.validate_stream( + value, metadata, validator_setup, iteration, **kwargs + ) try: loop = asyncio.get_event_loop() except RuntimeError: diff --git a/tests/integration_tests/test_streaming.py b/tests/integration_tests/test_streaming.py index 4551f8f4e..623ce4a9f 100644 --- a/tests/integration_tests/test_streaming.py +++ b/tests/integration_tests/test_streaming.py @@ -1,17 +1,27 @@ -# 2 tests +# 3 tests # 1. Test streaming with OpenAICallable (mock openai.Completion.create) # 2. Test streaming with OpenAIChatCallable (mock openai.ChatCompletion.create) -# Using the LowerCase Validator +# 3. Test string schema streaming +# Using the LowerCase Validator, and a custom validator to show new streaming behavior import json -from typing import Iterable +from typing import Any, Callable, Dict, Iterable, List, Optional, Union import openai import pytest from pydantic import BaseModel, Field import guardrails as gd +from guardrails.utils.casting_utils import to_int from guardrails.utils.openai_utils import OPENAI_VERSION -from guardrails.validator_base import OnFailAction +from guardrails.validator_base import ( + ErrorSpan, + FailResult, + OnFailAction, + PassResult, + ValidationResult, + Validator, + register_validator, +) from guardrails.validators import LowerCase expected_raw_output = {"statement": "I am DOING well, and I HOPE you aRe too."} @@ -20,6 +30,65 @@ expected_filter_refrain_output = {} +@register_validator(name="minsentencelength", data_type=["string", "list"]) +class MinSentenceLengthValidator(Validator): + def __init__( + self, + min: Optional[int] = None, + max: Optional[int] = None, + on_fail: Optional[Callable] = None, + ): + super().__init__( + on_fail=on_fail, + min=min, + max=max, + ) + self._min = to_int(min) + self._max = to_int(max) + + def sentence_split(self, value): + return list(map(lambda x: x + ".", value.split(".")[:-1])) + + def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult: + sentences = self.sentence_split(value) + error_spans = [] + index = 0 + for sentence in sentences: + if len(sentence) < self._min: + error_spans.append( + ErrorSpan( + start=index, + end=index + len(sentence), + reason=f"Sentence has length less than {self._min}. " + f"Please return a longer output, " + f"that is shorter than {self._max} characters.", + ) + ) + if len(sentence) > self._max: + error_spans.append( + ErrorSpan( + start=index, + end=index + len(sentence), + reason=f"Sentence has length greater than {self._max}. " + f"Please return a shorter output, " + f"that is shorter than {self._max} characters.", + ) + ) + index = index + len(sentence) + if len(error_spans) > 0: + return FailResult( + validated_chunk=value, + error_spans=error_spans, + error_message=f"Sentence has length less than {self._min}. 
" + f"Please return a longer output, " + f"that is shorter than {self._max} characters.", + ) + return PassResult() + + def validate_stream(self, chunk: Any, metadata: Dict, **kwargs) -> ValidationResult: + return super().validate_stream(chunk, metadata, **kwargs) + + class Delta: content: str @@ -49,20 +118,18 @@ def __init__(self, choices, model): self.model = model -def mock_openai_completion_create(): +def mock_openai_completion_create(chunks): # Returns a generator - chunks = [ - '{"statement":', - ' "I am DOING', - " well, and I", - " HOPE you aRe", - ' too."}', - ] - def gen(): + index = 0 for chunk in chunks: + index = index + 1 + # finished = index == len(chunks) + # finish_reason = "stop" if finished else None + # print("FINISH REASON", finish_reason) if OPENAI_VERSION.startswith("0"): yield { + # TODO: for some reason using finish_reason here breaks everything "choices": [{"text": chunk, "finish_reason": None}], "model": "OpenAI model name", } @@ -72,6 +139,7 @@ def gen(): Choice( text=chunk, delta=Delta(content=""), + # TODO: for some reason using finish_reason here breaks everything # noqa finish_reason=None, ) ], @@ -81,24 +149,22 @@ def gen(): return gen() -def mock_openai_chat_completion_create(): +def mock_openai_chat_completion_create(chunks): # Returns a generator - chunks = [ - '{"statement":', - ' "I am DOING', - " well, and I", - " HOPE you aRe", - ' too."}', - ] - def gen(): + index = 0 for chunk in chunks: + index = index + 1 + # finished = index == len(chunks) + # finish_reason = "stop" if finished else None + # print("FINISH REASON", finish_reason) if OPENAI_VERSION.startswith("0"): yield { "choices": [ { "index": 0, "delta": {"content": chunk}, + # TODO: for some reason using finish_reason here breaks everything # noqa "finish_reason": None, } ] @@ -109,6 +175,7 @@ def gen(): Choice( text="", delta=Delta(content=chunk), + # TODO: for some reason using finish_reason here breaks everything # noqa finish_reason=None, ) ], @@ -146,25 +213,57 @@ class LowerCaseRefrain(BaseModel): ) +expected_minsentence_noop_output = "" + + +class MinSentenceLengthNoOp(BaseModel): + statement: str = Field( + description="Validates whether the text is in lower case.", + validators=[MinSentenceLengthValidator(on_fail=OnFailAction.NOOP)], + ) + + +STR_PROMPT = "Say something nice to me." + PROMPT = """ Say something nice to me. ${gr.complete_json_suffix} """ +JSON_LLM_CHUNKS = [ + '{"statement":', + ' "I am DOING', + " well, and I", + " HOPE you aRe", + ' too."}', +] + @pytest.mark.parametrize( - "op_class, expected_validated_output", + "guard, expected_validated_output", [ - (LowerCaseNoop, expected_noop_output), - (LowerCaseFix, expected_fix_output), - (LowerCaseFilter, expected_filter_refrain_output), - (LowerCaseRefrain, expected_filter_refrain_output), + ( + gd.Guard.from_pydantic(output_class=LowerCaseNoop, prompt=PROMPT), + expected_noop_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseFix, prompt=PROMPT), + expected_fix_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseFilter, prompt=PROMPT), + expected_filter_refrain_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseRefrain, prompt=PROMPT), + expected_filter_refrain_output, + ), ], ) def test_streaming_with_openai_callable( mocker, - op_class, + guard, expected_validated_output, ): """Test streaming with OpenAICallable. 
@@ -173,17 +272,15 @@ def test_streaming_with_openai_callable( """ if OPENAI_VERSION.startswith("0"): mocker.patch( - "openai.Completion.create", return_value=mock_openai_completion_create() + "openai.Completion.create", + return_value=mock_openai_completion_create(JSON_LLM_CHUNKS), ) else: mocker.patch( "openai.resources.Completions.create", - return_value=mock_openai_completion_create(), + return_value=mock_openai_completion_create(JSON_LLM_CHUNKS), ) - # Create a guard object - guard = gd.Guard.from_pydantic(output_class=op_class, prompt=PROMPT) - method = ( openai.Completion.create if OPENAI_VERSION.startswith("0") @@ -210,17 +307,29 @@ def test_streaming_with_openai_callable( @pytest.mark.parametrize( - "op_class, expected_validated_output", + "guard, expected_validated_output", [ - (LowerCaseNoop, expected_noop_output), - (LowerCaseFix, expected_fix_output), - (LowerCaseFilter, expected_filter_refrain_output), - (LowerCaseRefrain, expected_filter_refrain_output), + ( + gd.Guard.from_pydantic(output_class=LowerCaseNoop, prompt=PROMPT), + expected_noop_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseFix, prompt=PROMPT), + expected_fix_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseFilter, prompt=PROMPT), + expected_filter_refrain_output, + ), + ( + gd.Guard.from_pydantic(output_class=LowerCaseRefrain, prompt=PROMPT), + expected_filter_refrain_output, + ), ], ) def test_streaming_with_openai_chat_callable( mocker, - op_class, + guard, expected_validated_output, ): """Test streaming with OpenAIChatCallable. @@ -230,17 +339,14 @@ def test_streaming_with_openai_chat_callable( if OPENAI_VERSION.startswith("0"): mocker.patch( "openai.ChatCompletion.create", - return_value=mock_openai_chat_completion_create(), + return_value=mock_openai_chat_completion_create(JSON_LLM_CHUNKS), ) else: mocker.patch( "openai.resources.chat.completions.Completions.create", - return_value=mock_openai_chat_completion_create(), + return_value=mock_openai_chat_completion_create(JSON_LLM_CHUNKS), ) - # Create a guard object - guard = gd.Guard.from_pydantic(output_class=op_class, prompt=PROMPT) - method = ( openai.ChatCompletion.create if OPENAI_VERSION.startswith("0") @@ -265,3 +371,99 @@ def test_streaming_with_openai_chat_callable( assert actual_output.raw_llm_output == json.dumps(expected_raw_output) assert actual_output.validated_output == expected_validated_output + + +STR_LLM_CHUNKS = [ + # 38 characters + "This sentence is simply just ", + "too long." + # 25 characters long + "This ", + "sentence ", + "is 2 ", + "short." + # 29 characters long + "This sentence is just ", + "right.", +] + + +@pytest.mark.parametrize( + "guard, expected_error_spans", + [ + ( + gd.Guard.from_string( + # only the middle sentence should pass + validators=[ + MinSentenceLengthValidator(26, 30, on_fail=OnFailAction.NOOP) + ], + prompt=STR_PROMPT, + ), + # each value is a tuple + # first is expected text inside span + # second is the reason for failure + [ + [ + "This sentence is simply just too long.", + ( + "Sentence has length greater than 30. " + "Please return a shorter output, " + "that is shorter than 30 characters." + ), + ], + [ + "This sentence is 2 short.", + ( + "Sentence has length less than 26. " + "Please return a longer output, " + "that is shorter than 30 characters." + ), + ], + ], + ) + ], +) +def test_string_schema_streaming_with_openai_chat(mocker, guard, expected_error_spans): + """Test string schema streaming with OpenAIChatCallable. + + Mocks openai.ChatCompletion.create. 
+ """ + if OPENAI_VERSION.startswith("0"): + mocker.patch( + "openai.ChatCompletion.create", + return_value=mock_openai_chat_completion_create(STR_LLM_CHUNKS), + ) + else: + mocker.patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_openai_chat_completion_create(STR_LLM_CHUNKS), + ) + + method = ( + openai.ChatCompletion.create + if OPENAI_VERSION.startswith("0") + else openai.chat.completions.create + ) + + method.__name__ = "mock_openai_chat_completion_create" + generator = guard( + method, + model="gpt-3.5-turbo", + max_tokens=10, + temperature=0, + stream=True, + ) + + assert isinstance(generator, Iterable) + + accumulated_output = "" + for op in generator: + accumulated_output += op.raw_llm_output + error_spans = guard.error_spans_in_output() + + # print spans + assert len(error_spans) == len(expected_error_spans) + for error_span, expected in zip(error_spans, expected_error_spans): + assert accumulated_output[error_span.start : error_span.end] == expected[0] + assert error_span.reason == expected[1] + # TODO assert something about these error spans diff --git a/tests/unit_tests/test_validators.py b/tests/unit_tests/test_validators.py index 3f6f2c8c9..e8cf1ea73 100644 --- a/tests/unit_tests/test_validators.py +++ b/tests/unit_tests/test_validators.py @@ -899,11 +899,11 @@ async def mock_llm_api(*args, **kwargs): [ ( OnFailAction.REASK, - "Prompt validation failed: incorrect_value='What kind of pet should I get?\\n\\nJson Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Instructions validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Message history validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Prompt validation failed: incorrect_value='\\nThis is not two words\\n\\n\\nString Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='This is')] path=None", # noqa - "Instructions validation failed: incorrect_value='\\nThis also is not two words\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='This also')] path=None", # noqa + "Prompt validation failed: incorrect_value='What kind of pet should I get?\\n\\nJson Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Instructions validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Message history validation failed: incorrect_value='What kind of pet should I get?' 
fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Prompt validation failed: incorrect_value='\\nThis is not two words\\n\\n\\nString Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='This is', error_spans=None)] path=None", # noqa + "Instructions validation failed: incorrect_value='\\nThis also is not two words\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='This also', error_spans=None)] path=None", # noqa ), ( OnFailAction.FILTER, @@ -1044,11 +1044,11 @@ def test_input_validation_fail( [ ( OnFailAction.REASK, - "Prompt validation failed: incorrect_value='What kind of pet should I get?\\n\\nJson Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Instructions validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Message history validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='What kind')] path=None", # noqa - "Prompt validation failed: incorrect_value='\\nThis is not two words\\n\\n\\nString Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='This is')] path=None", # noqa - "Instructions validation failed: incorrect_value='\\nThis also is not two words\\n' fail_results=[FailResult(outcome='fail', metadata=None, error_message='must be exactly two words', fix_value='This also')] path=None", # noqa + "Prompt validation failed: incorrect_value='What kind of pet should I get?\\n\\nJson Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Instructions validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Message history validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='What kind', error_spans=None)] path=None", # noqa + "Prompt validation failed: incorrect_value='\\nThis is not two words\\n\\n\\nString Output:\\n\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='This is', error_spans=None)] path=None", # noqa + "Instructions validation failed: incorrect_value='\\nThis also is not two words\\n' fail_results=[FailResult(outcome='fail', metadata=None, validated_chunk=None, error_message='must be exactly two words', fix_value='This also', error_spans=None)] path=None", # noqa ), ( OnFailAction.FILTER,