
Streaming chunk accumulation #741


Merged

Changes shown from 21 of 45 commits.

Commits
29ee128
preliminary code and pseudocode
nichwch May 7, 2024
66a5333
add chunk accumulation strategy to Validator base class
nichwch May 7, 2024
7831466
handle case where llm chunk > validator chunk in validator class
nichwch May 8, 2024
2dbae2e
added some questions
nichwch May 8, 2024
1e3544d
change stream_runner to handle the result of iterable validate
nichwch May 8, 2024
e9084e1
format
nichwch May 8, 2024
1269f68
change validator base to use a chunking function instead of specifyin…
nichwch May 10, 2024
b454cd5
connect streaming all the way down call chain, include validated chun…
nichwch May 10, 2024
b64ab4e
change execute_validator to handle streaming
nichwch May 10, 2024
bf2bd32
make validate take stream parameter, remove validate_stream in top le…
nichwch May 10, 2024
c79e9b2
use wyatts sentence splitting strategy
nichwch May 10, 2024
4583cb9
import nltk
nichwch May 10, 2024
f1b4a88
use stream-enabled execute_validator
nichwch May 14, 2024
289745c
format
nichwch May 14, 2024
58d8eed
fix bug where json_schema was being called with streaming
nichwch May 15, 2024
947f476
conditionally use old logic for json_schema to avoid breaking json_sc…
nichwch May 16, 2024
8b2c154
validate remainders
nichwch May 16, 2024
d6c3739
merge in main
nichwch May 16, 2024
0ab245c
new chunk span validation schema
nichwch May 16, 2024
a320464
field for reason that validation failed for a given span
nichwch May 16, 2024
93bb781
add validated_chunk to ValidationResult
nichwch May 17, 2024
1381821
add helper method to get a list of error spans relative to llm output
nichwch May 17, 2024
3ccdda1
conceptual question
nichwch May 17, 2024
6fdbcd1
Merge branch 'main' into nichwch/chunk-accumulation-rewrite
nichwch May 17, 2024
f455ae2
Merge branch 'nichwch/chunk-accumulation-rewrite' into nichwch/stream…
nichwch May 17, 2024
74485eb
turn chunking_function into class method
nichwch May 17, 2024
a39b5af
incomplete tests for streaming chunk accumulation
nichwch May 17, 2024
0ae850e
format
nichwch May 20, 2024
847dd0a
remove print
nichwch May 20, 2024
f0b3030
fix a few bugs uncovered by testing
nichwch May 20, 2024
e8b6069
tests (WIP) for streaming
nichwch May 20, 2024
628e490
Merge branch 'main' into nichwch/chunk-accumulation-rewrite
nichwch May 21, 2024
a9a91a1
merge
nichwch May 21, 2024
eec8e19
base model
nichwch May 21, 2024
8726a28
optional typing to avoid breaking existing validators
nichwch May 21, 2024
ba68eb6
top level helper function for spans on guard, patch validated_chunk i…
nichwch May 21, 2024
2607423
attempt to use openai finish_reason field
nichwch May 21, 2024
da720c3
add comment explaining problem with using openai finish_message
nichwch May 21, 2024
8bdb292
test error span behavior
nichwch May 22, 2024
0abac83
address some changes
nichwch May 28, 2024
8f45a0a
handle case where llm callable doesnt provide finished flag
nichwch May 28, 2024
dfcd3b8
Merge pull request #771 from guardrails-ai/nichwch/streaming-error-spans
nichwch May 28, 2024
fe56871
Merge branch 'feat/streaming-update' into nichwch/chunk-accumulation-…
CalebCourier May 30, 2024
b52b8cb
lint, type, and test fixes
CalebCourier May 30, 2024
0aede77
use status for validation_passed in streaming
CalebCourier May 30, 2024
28 changes: 20 additions & 8 deletions guardrails/run/runner.py
@@ -547,17 +547,29 @@ def validate(
index: int,
parsed_output: Any,
output_schema: Schema,
stream: Optional[bool] = False,
**kwargs,
):
"""Validate the output."""
validated_output = output_schema.validate(
iteration,
parsed_output,
self.metadata,
attempt_number=index,
disable_tracer=self._disable_tracer,
**kwargs,
)
if isinstance(output_schema, StringSchema):
validated_output = output_schema.validate(
iteration,
parsed_output,
self.metadata,
index,
self._disable_tracer,
stream,
**kwargs,
)
else:
validated_output = output_schema.validate(
iteration,
parsed_output,
self.metadata,
attempt_number=index,
disable_tracer=self._disable_tracer,
**kwargs,
)

return validated_output

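For orientation, a minimal sketch of a call into this method with the new stream flag; runner, iteration, chunk_text, and schema are illustrative stand-ins assumed for the example, not a prescribed setup:

# Hedged sketch: the keyword names mirror the signature above, but the
# surrounding objects are assumptions made for illustration only.
validated = runner.validate(
    iteration=iteration,
    index=0,
    parsed_output=chunk_text,
    output_schema=schema,   # a StringSchema instance
    stream=True,            # forwarded to StringSchema.validate; other schemas keep the old path
)
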
121 changes: 101 additions & 20 deletions guardrails/run/stream_runner.py
@@ -153,47 +153,128 @@ def step(
verified = set()
# Loop over the stream
# and construct "fragments" of concatenated chunks
for chunk in stream:
# 1. Get the text from the chunk and append to fragment
chunk_text = self.get_chunk_text(chunk, api)
fragment += chunk_text
# for now, handle string and json schema differently

# 2. Parse the fragment
parsed_fragment, move_to_next = self.parse(
index, fragment, output_schema, verified
)
if move_to_next:
# Continue to next chunk
continue
if isinstance(output_schema, StringSchema):
for chunk in stream:
# 1. Get the text from the chunk and append to fragment
chunk_text = self.get_chunk_text(chunk, api)
fragment += chunk_text

# 2. Parse the chunk
parsed_chunk, move_to_next = self.parse(
index, chunk_text, output_schema, verified
)
if move_to_next:
# Continue to next chunk
continue
validated_result = self.validate(
iteration,
index,
parsed_chunk,
output_schema,
True,
validate_subschema=True,
)
if isinstance(validated_result, SkeletonReAsk):
raise ValueError(
"Received fragment schema is an invalid sub-schema "
"of the expected output JSON schema."
)

# 3. Run output validation
validated_fragment = self.validate(
# 4. Introspect: inspect the validated fragment for reasks
reasks, valid_op = self.introspect(
index, validated_result, output_schema
)
if reasks:
raise ValueError(
"Reasks are not yet supported with streaming. Please "
"remove reasks from schema or disable streaming."
)
# 5. Convert validated fragment to a pretty JSON string
yield ValidationOutcome(
# The chunk or the whole output?
raw_llm_output=chunk,
validated_output=validated_result,
validation_passed=validated_fragment is not None,
)
######################################
# need to validate remainder of chunks
######################################
remainder_validation = self.validate(
iteration,
index,
parsed_fragment,
"",
output_schema,
True,
validate_subschema=True,
remainder=True,
)
if isinstance(validated_fragment, SkeletonReAsk):
if isinstance(remainder_validation, SkeletonReAsk):
raise ValueError(
"Received fragment schema is an invalid sub-schema "
"of the expected output JSON schema."
)

# 4. Introspect: inspect the validated fragment for reasks
reasks, valid_op = self.introspect(index, validated_fragment, output_schema)
reasks, valid_op = self.introspect(
index, remainder_validation, output_schema
)
if reasks:
raise ValueError(
"Reasks are not yet supported with streaming. Please "
"remove reasks from schema or disable streaming."
)

# 5. Convert validated fragment to a pretty JSON string
yield ValidationOutcome(
raw_llm_output=fragment,
validated_output=validated_fragment,
validation_passed=validated_fragment is not None,
# The chunk or the whole output?
raw_llm_output=chunk,
validated_output=remainder_validation,
validation_passed=remainder_validation is not None,
)
# handle non string schema
else:
for chunk in stream:
# 1. Get the text from the chunk and append to fragment
chunk_text = self.get_chunk_text(chunk, api)
fragment += chunk_text

parsed_fragment, move_to_next = self.parse(
index, fragment, output_schema, verified
)
if move_to_next:
# Continue to next chunk
continue
validated_fragment = self.validate(
iteration,
index,
parsed_fragment,
output_schema,
validate_subschema=True,
)
if isinstance(validated_fragment, SkeletonReAsk):
raise ValueError(
"Received fragment schema is an invalid sub-schema "
"of the expected output JSON schema."
)

# 4. Introspect: inspect the validated fragment for reasks
reasks, valid_op = self.introspect(
index, validated_fragment, output_schema
)
if reasks:
raise ValueError(
"Reasks are not yet supported with streaming. Please "
"remove reasks from schema or disable streaming."
)

# 5. Convert validated fragment to a pretty JSON string
yield ValidationOutcome(
raw_llm_output=fragment,
validated_output=validated_fragment,
validation_passed=validated_fragment is not None,
)

# Finally, add to logs
iteration.outputs.raw_output = fragment
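
To make the control flow above easier to follow, here is a simplified, self-contained sketch of the string-schema branch: one validation call per incoming chunk, followed by a final remainder pass over whatever is still buffered. The names step_string_schema and upper_sentences are invented for illustration and are not part of guardrails.

from typing import Callable, Iterable, Optional


def step_string_schema(
    stream_chunks: Iterable[str],
    validate: Callable[..., Optional[str]],
):
    """Mirror of the string-schema loop: per-chunk validation, then a remainder pass."""
    fragment = ""
    for chunk_text in stream_chunks:
        fragment += chunk_text
        yield chunk_text, validate(chunk_text)   # validators accumulate internally
    yield "", validate("", remainder=True)       # validate whatever is still buffered


# Toy validator: upper-case each complete sentence, buffer everything else.
buffered: list = []


def upper_sentences(chunk: str, remainder: bool = False) -> Optional[str]:
    buffered.append(chunk)
    text = "".join(buffered)
    if remainder or text.endswith(". "):
        buffered.clear()
        return text.upper()
    return None                                  # still accumulating


for raw, validated in step_string_schema(["Hi ", "there. ", "Bye"], upper_sentences):
    print(repr(raw), "->", repr(validated))
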
6 changes: 4 additions & 2 deletions guardrails/schema/string_schema.py
@@ -134,6 +134,7 @@ def validate(
metadata: Dict,
attempt_number: int = 0,
disable_tracer: Optional[bool] = True,
stream: Optional[bool] = False,
**kwargs,
) -> Any:
"""Validate a dictionary of data against the schema.
@@ -160,19 +161,20 @@
dummy_key: data,
},
)

validated_response, metadata = validator_service.validate(
value=data,
metadata=metadata,
validator_setup=validation,
iteration=iteration,
disable_tracer=disable_tracer,
stream=stream,
**kwargs,
)

validated_response = {dummy_key: validated_response}

if check_refrain_in_dict(validated_response):
# If the data contains a `Refain` value, we return an empty
# If the data contains a `Refrain` value, we return an empty
# dictionary.
logger.debug("Refrain detected.")
validated_response = {}
64 changes: 64 additions & 0 deletions guardrails/validator_base.py
@@ -1,4 +1,5 @@
import inspect
import nltk
from collections import defaultdict
from copy import deepcopy
from enum import Enum
@@ -175,6 +176,28 @@ class Refrain:
pass


# functions to get chunks
def split_word(chunk: str):
return list(map(lambda x: x + " ", chunk.split(" ")))[:-1]


def split_sentence(chunk: str):
# using the sentence tokenizer is expensive
# we check for a . to avoid wastefully calling the tokenizer
if "." not in chunk:
return []
sentences = nltk.sent_tokenize(chunk)
if len(sentences) == 0:
return []
# return the sentence
# then the remaining chunks that aren't finished accumulating
return [sentences[0], "".join(sentences[1:])]


def split_paragraph(chunk: str):
return list(map(lambda x: x + "\n", chunk.split("\n")))[:-1]


def check_refrain_in_list(schema: List) -> bool:
"""Checks if a Refrain object exists in a list.

@@ -390,6 +413,10 @@ class Validator(Runnable):

rail_alias: str = ""

# chunking function returns empty list or list of 2 chunks
# first chunk is the chunk to validate
# second chunk is incomplete chunk that needs further accumulation
accumulated_chunks = []
run_in_separate_process = False
override_value_on_pass = False
required_metadata_keys = []
@@ -448,10 +475,47 @@ def __init__(
self.rail_alias in validators_registry
), f"Validator {self.__class__.__name__} is not registered. "

def chunking_function(self, chunk: str):
return split_sentence(chunk)

def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult:
"""Validates a value and return a validation result."""
raise NotImplementedError

def validate_stream(
self, chunk: Any, metadata: Dict[str, Any], **kwargs
) -> ValidationResult:
"""Validates a chunk emitted by an LLM.
If the LLM chunk is smaller than the validator's chunking strategy,
it will be accumulated until it reaches the desired size. In the meantime,
the validator will return None.

If the LLM chunk is larger than the validator's chunking strategy,
it will split it into validator-sized chunks and validate each one,
returning an array of validation results.

Otherwise, the validator will validate the chunk and return the result.
"""
# combine accumulated chunks and new chunk
self.accumulated_chunks.append(chunk)
accumulated_text = "".join(self.accumulated_chunks)
# check if enough chunks have accumulated for validation
splitcontents = self.chunking_function(accumulated_text)

# if remainder kwargs is passed, validate remainder regardless
remainder = kwargs.get("remainder", False)
if remainder:
splitcontents = [accumulated_text, []]
if len(splitcontents) == 0:
return None
[chunk_to_validate, new_accumulated_chunks] = splitcontents
self.accumulated_chunks = new_accumulated_chunks
# exclude last chunk, because it may not be a complete chunk
validation_result = self.validate(chunk_to_validate, metadata)
# include the chunk that we've validated in the metadata
validation_result.metadata["validated_chunk"] = chunk_to_validate
return validation_result

def to_prompt(self, with_keywords: bool = True) -> str:
"""Convert the validator to a prompt.

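For a feel of what the chunking helpers defined above return, a short illustration follows. It assumes the functions stay importable from guardrails.validator_base and that the nltk punkt tokenizer data is installed; the exact sentence boundaries come from nltk.sent_tokenize, so treat the second result as approximate.

from guardrails.validator_base import split_paragraph, split_sentence, split_word

print(split_word("foo bar baz"))        # ['foo ', 'bar ']    - trailing partial word is not returned
print(split_sentence("Hi there. How"))  # ['Hi there.', 'How'] - first complete sentence, then the remainder
print(split_sentence("no period yet"))  # []                   - nothing complete yet, keep accumulating
print(split_paragraph("a\nb\nc"))       # ['a\n', 'b\n']       - trailing partial paragraph is held back
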
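Finally, a compact, self-contained sketch of the accumulation behaviour that validate_stream implements: buffer incoming LLM chunks until the chunking function yields a complete unit, validate that unit, and carry the leftover into the next round. Word-level splitting is used here so the example runs without nltk; split_words and validate_stream_sketch are illustrative stand-ins, not library code.

from typing import Iterable, Iterator, List, Tuple


def split_words(text: str) -> Tuple[List[str], str]:
    """Return (complete words, trailing partial word that keeps accumulating)."""
    if " " not in text:
        return [], text
    *complete, partial = text.split(" ")
    return [word + " " for word in complete], partial


def validate_stream_sketch(llm_chunks: Iterable[str]) -> Iterator[str]:
    buffer = ""
    for chunk in llm_chunks:
        buffer += chunk                      # accumulate until a full unit exists
        complete, buffer = split_words(buffer)
        for unit in complete:
            yield unit                       # each complete unit is what validate() would receive
    if buffer:
        yield buffer                         # final remainder pass (remainder=True)


print(list(validate_stream_sketch(["he", "llo wor", "ld again"])))
# ['hello ', 'world ', 'again']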