
Streaming chunk accumulation #741


Merged

Commits (45 total; changes shown from 42 commits)
29ee128
preliminary code and pseudocode
nichwch May 7, 2024
66a5333
add chunk accumulation strategy to Validator base class
nichwch May 7, 2024
7831466
handle case where llm chunk > validator chunk in validator class
nichwch May 8, 2024
2dbae2e
added some questions
nichwch May 8, 2024
1e3544d
change stream_runner to handle the result of iterable validate
nichwch May 8, 2024
e9084e1
format
nichwch May 8, 2024
1269f68
change validator base to use a chunking function instead of specifyin…
nichwch May 10, 2024
b454cd5
connect streaming all the way down call chain, include validated chun…
nichwch May 10, 2024
b64ab4e
change execute_validator to handle streaming
nichwch May 10, 2024
bf2bd32
make validate take stream parameter, remove validate_stream in top le…
nichwch May 10, 2024
c79e9b2
use wyatts sentence splitting strategy
nichwch May 10, 2024
4583cb9
import nltk
nichwch May 10, 2024
f1b4a88
use stream-enabled execute_validator
nichwch May 14, 2024
289745c
format
nichwch May 14, 2024
58d8eed
fix bug where json_schema was being called with streaming
nichwch May 15, 2024
947f476
conditionally use old logic for json_schema to avoid breaking json_sc…
nichwch May 16, 2024
8b2c154
validate remainders
nichwch May 16, 2024
d6c3739
merge in main
nichwch May 16, 2024
0ab245c
new chunk span validation schema
nichwch May 16, 2024
a320464
field for reason that validation failed for a given span
nichwch May 16, 2024
93bb781
add validated_chunk to ValidationResult
nichwch May 17, 2024
1381821
add helper method to get a list of error spans relative to llm output
nichwch May 17, 2024
3ccdda1
conceptual question
nichwch May 17, 2024
6fdbcd1
Merge branch 'main' into nichwch/chunk-accumulation-rewrite
nichwch May 17, 2024
f455ae2
Merge branch 'nichwch/chunk-accumulation-rewrite' into nichwch/stream…
nichwch May 17, 2024
74485eb
turn chunking_function into class method
nichwch May 17, 2024
a39b5af
incomplete tests for streaming chunk accumulation
nichwch May 17, 2024
0ae850e
format
nichwch May 20, 2024
847dd0a
remove print
nichwch May 20, 2024
f0b3030
fix a few bugs uncovered by testing
nichwch May 20, 2024
e8b6069
tests (WIP) for streaming
nichwch May 20, 2024
628e490
Merge branch 'main' into nichwch/chunk-accumulation-rewrite
nichwch May 21, 2024
a9a91a1
merge
nichwch May 21, 2024
eec8e19
base model
nichwch May 21, 2024
8726a28
optional typing to avoid breaking existing validators
nichwch May 21, 2024
ba68eb6
top level helper function for spans on guard, patch validated_chunk i…
nichwch May 21, 2024
2607423
attempt to use openai finish_reason field
nichwch May 21, 2024
da720c3
add comment explaining problem with using openai finish_message
nichwch May 21, 2024
8bdb292
test error span behavior
nichwch May 22, 2024
0abac83
address some changes
nichwch May 28, 2024
8f45a0a
handle case where llm callable doesnt provide finished flag
nichwch May 28, 2024
dfcd3b8
Merge pull request #771 from guardrails-ai/nichwch/streaming-error-spans
nichwch May 28, 2024
fe56871
Merge branch 'feat/streaming-update' into nichwch/chunk-accumulation-…
CalebCourier May 30, 2024
b52b8cb
lint, type, and test fixes
CalebCourier May 30, 2024
0aede77
use status for validation_passed in streaming
CalebCourier May 30, 2024
7 changes: 7 additions & 0 deletions guardrails/classes/history/iteration.py
@@ -15,6 +15,7 @@
from guardrails.utils.logs_utils import ValidatorLogs
from guardrails.utils.pydantic_utils import ArbitraryModel
from guardrails.utils.reask_utils import ReAsk
from guardrails.validator_base import ErrorSpan


class Iteration(ArbitraryModel):
@@ -155,6 +156,12 @@ def failed_validations(self) -> List[ValidatorLogs]:
iteration."""
return self.outputs.failed_validations

@property
def error_spans_in_output(self) -> List[ErrorSpan]:
"""The error spans from the LLM response.
These indices are relative to the complete LLM output."""
return self.outputs.error_spans_in_output

@property
def status(self) -> str:
"""Representation of the end state of this iteration.
25 changes: 23 additions & 2 deletions guardrails/classes/history/outputs.py
@@ -1,5 +1,4 @@
from typing import Dict, List, Optional, Sequence, Union

from pydantic import Field
from typing_extensions import deprecated

@@ -8,7 +7,7 @@
from guardrails.utils.logs_utils import ValidatorLogs
from guardrails.utils.pydantic_utils import ArbitraryModel
from guardrails.utils.reask_utils import ReAsk
from guardrails.validator_base import FailResult
from guardrails.validator_base import ErrorSpan, FailResult


class Outputs(ArbitraryModel):
@@ -75,6 +74,28 @@ def failed_validations(self) -> List[ValidatorLogs]:
]
)

@property
def error_spans_in_output(self) -> List[ErrorSpan]:
"""The error spans from the LLM response.
These indices are relative to the complete LLM output."""
total_len = 0
spans_in_output = []
for log in self.validator_logs:
result = log.validation_result
if isinstance(result, FailResult):
if result.error_spans is not None:
for error_span in result.error_spans:
spans_in_output.append(
ErrorSpan(
start=error_span.start + total_len,
end=error_span.end + total_len,
reason=error_span.reason,
)
)
if result.validated_chunk is not None:
total_len += len(result.validated_chunk)
return spans_in_output

@property
def status(self) -> str:
"""Representation of the end state of the validation run.
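The offset arithmetic in `error_spans_in_output` above can be illustrated with a minimal, self-contained sketch. The `ErrorSpan` and `FailResult` shapes below mirror only the fields the property reads; the chunks and spans are invented for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ErrorSpan:
    start: int
    end: int
    reason: str


@dataclass
class FailResult:
    validated_chunk: Optional[str]
    error_spans: Optional[List[ErrorSpan]]


def error_spans_in_output(results: List[FailResult]) -> List[ErrorSpan]:
    # Each validator only saw its own chunk, so its spans are chunk-relative;
    # shifting by the running length of prior chunks makes them output-relative.
    total_len = 0
    spans: List[ErrorSpan] = []
    for result in results:
        if result.error_spans:
            for span in result.error_spans:
                spans.append(
                    ErrorSpan(span.start + total_len, span.end + total_len, span.reason)
                )
        if result.validated_chunk is not None:
            total_len += len(result.validated_chunk)
    return spans


results = [
    FailResult(validated_chunk="Hello world. ", error_spans=[ErrorSpan(6, 11, "flagged word")]),
    FailResult(validated_chunk="Second chunk.", error_spans=[ErrorSpan(0, 6, "flagged word")]),
]
# The second span is shifted by len("Hello world. ") == 13, giving (13, 19).
print(error_spans_in_output(results))
```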
9 changes: 9 additions & 0 deletions guardrails/guard.py
@@ -1130,6 +1130,15 @@ async def _async_parse(

return ValidationOutcome[OT].from_guard_history(call)

def error_spans_in_output(self):
try:
call = self.history[0]
iter = call.iterations[0]
llm_spans = iter.error_spans_in_output
return llm_spans
except (AttributeError, TypeError):
return []

@deprecated(
"""The `with_prompt_validation` method is deprecated,
and will be removed in 0.5.x. Instead, please use
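A hedged usage sketch for the new `Guard.error_spans_in_output()` helper after a streaming run. The prompt, model parameters, and empty validator list are placeholders (a real run needs a streaming-capable validator and an OpenAI API key), and the call pattern follows the 0.4-era `Guard.from_string` streaming style rather than anything defined in this diff.

```python
import openai
from guardrails import Guard

# An empty validator list keeps the sketch short; a real run would register
# a streaming-capable validator here.
guard = Guard.from_string(validators=[])

fragment_generator = guard(
    openai.chat.completions.create,
    prompt="Tell me a short story about streaming validation.",
    max_tokens=256,
    temperature=0,
    stream=True,
)

for outcome in fragment_generator:
    # With a StringSchema, each ValidationOutcome carries one validated chunk.
    print(outcome.validated_output, end="")

# Once the stream is consumed, span indices are relative to the full output.
for span in guard.error_spans_in_output():
    print(span.start, span.end, span.reason)
```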
28 changes: 20 additions & 8 deletions guardrails/run/runner.py
@@ -547,17 +547,29 @@ def validate(
index: int,
parsed_output: Any,
output_schema: Schema,
stream: Optional[bool] = False,
**kwargs,
):
"""Validate the output."""
validated_output = output_schema.validate(
iteration,
parsed_output,
self.metadata,
attempt_number=index,
disable_tracer=self._disable_tracer,
**kwargs,
)
if isinstance(output_schema, StringSchema):
validated_output = output_schema.validate(
iteration,
parsed_output,
self.metadata,
index,
self._disable_tracer,
stream,
**kwargs,
)
else:
validated_output = output_schema.validate(
iteration,
parsed_output,
self.metadata,
attempt_number=index,
disable_tracer=self._disable_tracer,
**kwargs,
)

return validated_output

165 changes: 130 additions & 35 deletions guardrails/run/stream_runner.py
@@ -153,54 +153,149 @@ def step(
verified = set()
# Loop over the stream
# and construct "fragments" of concatenated chunks
for chunk in stream:
# 1. Get the text from the chunk and append to fragment
chunk_text = self.get_chunk_text(chunk, api)
fragment += chunk_text
# for now, handle string and json schema differently

# 2. Parse the fragment
parsed_fragment, move_to_next = self.parse(
index, fragment, output_schema, verified
)
if move_to_next:
# Continue to next chunk
continue
if isinstance(output_schema, StringSchema):
stream_finished = False
last_chunk_text = ""
for chunk in stream:
# 1. Get the text from the chunk and append to fragment
chunk_text = self.get_chunk_text(chunk, api)
last_chunk_text = chunk_text
finished = self.is_last_chunk(chunk, api)
if finished:
stream_finished = True
fragment += chunk_text

# 3. Run output validation
validated_fragment = self.validate(
iteration,
index,
parsed_fragment,
output_schema,
validate_subschema=True,
)
if isinstance(validated_fragment, SkeletonReAsk):
raise ValueError(
"Received fragment schema is an invalid sub-schema "
"of the expected output JSON schema."
# 2. Parse the chunk
parsed_chunk, move_to_next = self.parse(
index, chunk_text, output_schema, verified
)
if move_to_next:
# Continue to next chunk
continue
validated_text = self.validate(
iteration,
index,
parsed_chunk,
output_schema,
True,
validate_subschema=True,
# if it is the last chunk, validate everything that's left
remainder=finished,
)
if isinstance(validated_text, SkeletonReAsk):
raise ValueError(
"Received fragment schema is an invalid sub-schema "
"of the expected output JSON schema."
)

# 4. Introspect: inspect the validated fragment for reasks
reasks, valid_op = self.introspect(index, validated_fragment, output_schema)
if reasks:
raise ValueError(
"Reasks are not yet supported with streaming. Please "
"remove reasks from schema or disable streaming."
# 4. Introspect: inspect the validated fragment for reasks
reasks, valid_op = self.introspect(index, validated_text, output_schema)
if reasks:
raise ValueError(
"Reasks are not yet supported with streaming. Please "
"remove reasks from schema or disable streaming."
)
# 5. Convert validated fragment to a pretty JSON string
yield ValidationOutcome(
# The chunk or the whole output?
raw_llm_output=chunk_text,
validated_output=validated_text,
validation_passed=validated_text is not None,
)
# handle case where generator doesn't give finished status
if not stream_finished:
last_result = self.validate(
iteration,
index,
"",
output_schema,
True,
validate_subschema=True,
remainder=True,
)
if len(last_result) > 0:
yield ValidationOutcome(
raw_llm_output=last_chunk_text,
validated_output=last_result,
validation_passed=last_result is not None,
)
# handle non string schema
else:
for chunk in stream:
# 1. Get the text from the chunk and append to fragment
chunk_text = self.get_chunk_text(chunk, api)
fragment += chunk_text

# 5. Convert validated fragment to a pretty JSON string
yield ValidationOutcome(
raw_llm_output=fragment,
validated_output=validated_fragment,
validation_passed=validated_fragment is not None,
)
parsed_fragment, move_to_next = self.parse(
index, fragment, output_schema, verified
)
if move_to_next:
# Continue to next chunk
continue
validated_fragment = self.validate(
iteration,
index,
parsed_fragment,
output_schema,
validate_subschema=True,
)
if isinstance(validated_fragment, SkeletonReAsk):
raise ValueError(
"Received fragment schema is an invalid sub-schema "
"of the expected output JSON schema."
)

# 4. Introspect: inspect the validated fragment for reasks
reasks, valid_op = self.introspect(
index, validated_fragment, output_schema
)
if reasks:
raise ValueError(
"Reasks are not yet supported with streaming. Please "
"remove reasks from schema or disable streaming."
)

# 5. Convert validated fragment to a pretty JSON string
yield ValidationOutcome(
raw_llm_output=fragment,
validated_output=validated_fragment,
validation_passed=validated_fragment is not None,
)

# Finally, add to logs
iteration.outputs.raw_output = fragment
iteration.outputs.parsed_output = parsed_fragment
iteration.outputs.validation_response = validated_fragment
iteration.outputs.guarded_output = valid_op

def is_last_chunk(self, chunk: Any, api: Union[PromptCallableBase, None]) -> bool:
"""Detect if chunk is final chunk"""
if isinstance(api, OpenAICallable):
if OPENAI_VERSION.startswith("0"):
finished = chunk["choices"][0]["finish_reason"]
return finished is not None
else:
finished = chunk.choices[0].finish_reason
return finished is not None
elif isinstance(api, OpenAIChatCallable):
if OPENAI_VERSION.startswith("0"):
finished = chunk["choices"][0]["finish_reason"]
return finished is not None
else:
finished = chunk.choices[0].finish_reason
return finished is not None
elif isinstance(api, LiteLLMCallable):
finished = chunk.choices[0].finish_reason
return finished is not None
else:
try:
finished = chunk.choices[0].finish_reason
return finished is not None
except (AttributeError, TypeError):
return False

def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str:
"""Get the text from a chunk."""
chunk_text = ""
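The `is_last_chunk` branches above all reduce to probing `finish_reason`; below is a minimal duck-typed sketch of that probe (the `looks_finished` name is invented). Providers that never set `finish_reason` return `False` for every chunk, which is why `step` falls back to a final `remainder=True` validation pass after the loop.

```python
def looks_finished(chunk) -> bool:
    """Best-effort check for the OpenAI-style finish_reason field."""
    try:
        # OpenAI v1 and LiteLLM chunks expose choices[0].finish_reason,
        # which stays None until the final chunk of the stream.
        return chunk.choices[0].finish_reason is not None
    except (AttributeError, IndexError, TypeError):
        # Callables that don't report a finished flag: the caller flushes
        # any buffered text with remainder=True once the stream ends.
        return False
```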
6 changes: 4 additions & 2 deletions guardrails/schema/string_schema.py
@@ -134,6 +134,7 @@ def validate(
metadata: Dict,
attempt_number: int = 0,
disable_tracer: Optional[bool] = True,
stream: Optional[bool] = False,
**kwargs,
) -> Any:
"""Validate a dictionary of data against the schema.
@@ -160,19 +161,20 @@
dummy_key: data,
},
)

validated_response, metadata = validator_service.validate(
value=data,
metadata=metadata,
validator_setup=validation,
iteration=iteration,
disable_tracer=disable_tracer,
stream=stream,
**kwargs,
)

validated_response = {dummy_key: validated_response}

if check_refrain_in_dict(validated_response):
# If the data contains a `Refain` value, we return an empty
# If the data contains a `Refrain` value, we return an empty
# dictionary.
logger.debug("Refrain detected.")
validated_response = {}