Skip to content

Fix Async Stream Contexts #1257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions guardrails/run/async_stream_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextvars import ContextVar, copy_context
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -118,6 +119,10 @@ async def async_step(
refrain_triggered = False
validation_passed = True

ctx_accumulated_chunks = ContextVar("accumulated_chunks")
ctx_accumulated_chunks.set([])
context = copy_context()

if self.output_type == OutputTypes.STRING:
validator_service = AsyncValidatorService(self.disable_tracer)
async for chunk in stream_output:
Expand All @@ -134,6 +139,8 @@ async def async_step(
"$",
"$",
True,
context=context,
ctx_accumulated_chunks=ctx_accumulated_chunks,
)
validators = self.validation_map.get("$", [])

Expand Down Expand Up @@ -240,6 +247,8 @@ async def async_step(
parsed_fragment,
output_schema,
validate_subschema=True,
context=context,
ctx_accumulated_chunks=ctx_accumulated_chunks,
)
if isinstance(validated_fragment, SkeletonReAsk):
raise ValueError(
Expand Down Expand Up @@ -275,6 +284,9 @@ def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> st
"""Get the text from a chunk."""
chunk_text = ""

if not chunk.choices or len(chunk.choices) == 0:
return chunk_text

try:
finished = chunk.choices[0].finish_reason
content = chunk.choices[0].delta.content
Expand Down
5 changes: 5 additions & 0 deletions guardrails/run/stream_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ def prepare_chunk_generator(stream) -> Iterator[Tuple[Any, bool]]:
def is_last_chunk(self, chunk: Any, api: Union[PromptCallableBase, None]) -> bool:
"""Detect if chunk is final chunk."""
try:
if (
not chunk.choices or len(chunk.choices) == 0
) and chunk.usage is not None:
# This is the last extra chunk for usage statistics
return True
finished = chunk.choices[0].finish_reason
return finished is not None
except (AttributeError, TypeError):
Expand Down
38 changes: 29 additions & 9 deletions guardrails/validator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# - [ ] Remove validator_base.py in 0.6.x

import asyncio
import contextlib
from contextvars import Context, ContextVar
from functools import partial
import inspect
import logging
Expand Down Expand Up @@ -67,10 +67,8 @@ def split_sentence_word_tokenizers_jl_separator(
# we check for a . to avoid wastefully calling the tokenizer

# check at least 3 characters have been accumulated before splitting
is_minimum_length = False
with contextlib.suppress(IndexError):
chunk[2]
is_minimum_length = True
third_chunk = safe_get(chunk, 2)
is_minimum_length = third_chunk is not None

# check for potential line endings, which is what split_sentences does
chunk_with_potential_line_endings, count = re.subn(
Expand Down Expand Up @@ -292,7 +290,13 @@ def _chunking_function(self, chunk: str) -> List[str]:
return split_sentence_word_tokenizers_jl_separator(chunk)

def validate_stream(
self, chunk: Any, metadata: Dict[str, Any], **kwargs
self,
chunk: Any,
metadata: Dict[str, Any],
*,
ctx_accumulated_chunks: Optional[ContextVar[List[str]]] = None,
context: Optional[Context] = None,
**kwargs,
) -> Optional[ValidationResult]:
"""Validates a chunk emitted by an LLM. If the LLM chunk is smaller
than the validator's chunking strategy, it will be accumulated until it
Expand All @@ -307,8 +311,13 @@ def validate_stream(
result.
"""
# combine accumulated chunks and new [:-1]chunk
self.accumulated_chunks.append(chunk)
accumulated_text = "".join(self.accumulated_chunks)
accumulated_chunks = (
context.run(ctx_accumulated_chunks.get)
if ctx_accumulated_chunks and context
else self.accumulated_chunks
)
accumulated_chunks.append(chunk)
accumulated_text = "".join(accumulated_chunks)
# check if enough chunks have accumulated for validation
split_contents = self._chunking_function(accumulated_text)

Expand All @@ -318,9 +327,16 @@ def validate_stream(
split_contents = [accumulated_text, ""]
# if no chunks are returned, we haven't accumulated enough
if len(split_contents) == 0:
if ctx_accumulated_chunks and context:
context.run(ctx_accumulated_chunks.set, accumulated_chunks)
else:
self.accumulated_chunks = accumulated_chunks
return None
[chunk_to_validate, new_accumulated_chunks] = split_contents
self.accumulated_chunks = [new_accumulated_chunks]
if ctx_accumulated_chunks and context:
context.run(ctx_accumulated_chunks.set, [new_accumulated_chunks])
else:
self.accumulated_chunks = [new_accumulated_chunks]
# exclude last chunk, because it may not be a complete chunk
validation_result = self.validate(chunk_to_validate, metadata)
# if validate doesn't set validated chunk, we set it
Expand All @@ -336,6 +352,10 @@ def validate_stream(
)
]

if ctx_accumulated_chunks:
ctx_accumulated_chunks.set(accumulated_chunks)
else:
self.accumulated_chunks = accumulated_chunks
return validation_result

async def async_validate_stream(
Expand Down
Loading