060 unify on messages and litellm #854

Closed · wants to merge 16 commits into from
68 changes: 37 additions & 31 deletions guardrails/actions/reask.py
@@ -247,7 +247,7 @@ def get_reask_setup_for_string(
validation_response: Optional[Union[str, List, Dict, ReAsk]] = None,
prompt_params: Optional[Dict[str, Any]] = None,
exec_options: Optional[GuardExecutionOptions] = None,
) -> Tuple[Dict[str, Any], Prompt, Instructions]:
) -> Tuple[Dict[str, Any], Messages]:
prompt_params = prompt_params or {}
exec_options = exec_options or GuardExecutionOptions()

@@ -300,7 +300,10 @@ def get_reask_setup_for_string(
messages = Messages(exec_options.reask_messages)
if messages is None:
messages = Messages(
[{"role": "system", "content": "You are a helpful assistant."}]
[
{"role": "system", "content": instructions},
{"role": "user", "content": prompt},
]
)

messages = messages.format(
@@ -309,7 +312,17 @@
**prompt_params,
)

return output_schema, prompt, instructions
return output_schema, messages
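
Note: with this hunk, get_reask_setup_for_string returns a (schema, Messages) pair instead of the old (schema, Prompt, Instructions) triple, and the default reask conversation is built from the schema's instructions and prompt rather than a generic system message. A hedged sketch of the new return shape from a caller's point of view — arguments not visible in the hunk are elided, and iteration over Messages is assumed:

# Hedged sketch, not part of the diff: unpacking the new two-element return.
output_schema, messages = get_reask_setup_for_string(
    # ...arguments before this point are unchanged and elided...
    prompt_params={"doc": "…"},            # interpolated into each message
    exec_options=GuardExecutionOptions(),  # reask_messages, if set, take precedence
)
for m in messages:                         # assumes Messages is iterable over role/content dicts
    print(m["role"], m["content"][:60])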


def get_original_messages(exec_options: GuardExecutionOptions) -> List[Dict[str, Any]]:
exec_options = exec_options or GuardExecutionOptions()
original_messages = exec_options.messages or []
messages_prompt = next(
(h.get("content") for h in original_messages if isinstance(h, dict)),
"",
)
return messages_prompt
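
For orientation, this new helper returns the content of the first dict-shaped message from the execution options (or an empty string when none are set); get_reask_setup_for_json below feeds that string to prompt_uses_xml to pick between the XML and JSON instruction constants. A small hedged illustration — the GuardExecutionOptions keyword argument is assumed from the attribute access above:

# Hedged illustration, not part of the diff.
opts = GuardExecutionOptions(
    messages=[
        {"role": "system", "content": "<output>…</output>"},
        {"role": "user", "content": "Fill out the schema."},
    ]
)
first_content = get_original_messages(opts)  # -> "<output>…</output>"
use_xml = prompt_uses_xml(first_content)     # drives XML vs. JSON instruction choice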


def get_original_prompt(exec_options: Optional[GuardExecutionOptions] = None) -> str:
@@ -338,7 +351,7 @@ def get_reask_setup_for_json(
use_full_schema: Optional[bool] = False,
prompt_params: Optional[Dict[str, Any]] = None,
exec_options: Optional[GuardExecutionOptions] = None,
) -> Tuple[Dict[str, Any], Prompt, Instructions]:
) -> Tuple[Dict[str, Any], Messages]:
reask_schema = output_schema
is_skeleton_reask = not any(isinstance(reask, FieldReAsk) for reask in reasks)
is_nonparseable_reask = any(
@@ -347,12 +360,10 @@
error_messages = {}
prompt_params = prompt_params or {}
exec_options = exec_options or GuardExecutionOptions()
original_prompt = get_original_prompt(exec_options)
use_xml = prompt_uses_xml(original_prompt)
original_messages = get_original_messages(exec_options)
use_xml = prompt_uses_xml(original_messages)

reask_prompt_template = None
if exec_options.reask_prompt:
reask_prompt_template = Prompt(exec_options.reask_prompt)

if is_nonparseable_reask:
if reask_prompt_template is None:
@@ -461,31 +472,26 @@ def reask_decoder(obj: ReAsk):
**prompt_params,
)

instructions = None
if exec_options.reask_instructions:
instructions = Instructions(exec_options.reask_instructions)
else:
instructions_const = (
constants["high_level_xml_instructions"]
if use_xml
else constants["high_level_json_instructions"]
)
instructions = Instructions(instructions_const)
instructions_const = (
constants["high_level_xml_instructions"]
if use_xml
else constants["high_level_json_instructions"]
)
instructions = Instructions(instructions_const)
instructions = instructions.format(**prompt_params)

# TODO: enable this in 0.6.0
# messages = None
# if exec_options.reask_messages:
# messages = Messages(exec_options.reask_messages)
# else:
# messages = Messages(
# [
# {"role": "system", "content": instructions},
# {"role": "user", "content": prompt},
# ]
# )
messages = None
if exec_options.reask_messages:
messages = Messages(exec_options.reask_messages)
else:
messages = Messages(
[
{"role": "system", "content": instructions},
{"role": "user", "content": prompt},
]
)

return reask_schema, prompt, instructions
return reask_schema, messages
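
Since exec_options.reask_messages takes precedence over the default system/user pair above, callers that previously set reask_prompt or reask_instructions would migrate to a message list. A hedged sketch — the ${...} placeholder is assumed to be resolved by the format calls above together with output_schema and prompt_params, and the constructor keyword is assumed:

# Hedged sketch, not part of the diff: overriding the default reask conversation.
reask_messages = [
    {"role": "system", "content": "You fix JSON that failed validation."},
    {"role": "user", "content": "Return corrected JSON matching ${output_schema}."},  # template var assumed
]
exec_options = GuardExecutionOptions(reask_messages=reask_messages)  # kwarg assumed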


def get_reask_setup(
@@ -499,7 +505,7 @@ def get_reask_setup(
use_full_schema: Optional[bool] = False,
prompt_params: Optional[Dict[str, Any]] = None,
exec_options: Optional[GuardExecutionOptions] = None,
) -> Tuple[Dict[str, Any], Prompt, Instructions]:
) -> Tuple[Dict[str, Any], Messages]:
prompt_params = prompt_params or {}
exec_options = exec_options or GuardExecutionOptions()

121 changes: 30 additions & 91 deletions guardrails/async_guard.py
@@ -92,23 +92,17 @@ def from_pydantic(
cls,
output_class: ModelOrListOfModels,
*,
prompt: Optional[str] = None,
instructions: Optional[str] = None,
num_reasks: Optional[int] = None,
reask_prompt: Optional[str] = None,
reask_instructions: Optional[str] = None,
messages: Optional[List[Dict]] = None,
reask_messages: Optional[List[Dict]] = None,
tracer: Optional[Tracer] = None,
name: Optional[str] = None,
description: Optional[str] = None,
):
guard = super().from_pydantic(
output_class,
prompt=prompt,
instructions=instructions,
messages=messages,
num_reasks=num_reasks,
reask_prompt=reask_prompt,
reask_instructions=reask_instructions,
reask_messages=reask_messages,
tracer=tracer,
name=name,
@@ -125,10 +119,8 @@ def from_string(
validators: Sequence[Validator],
*,
string_description: Optional[str] = None,
prompt: Optional[str] = None,
instructions: Optional[str] = None,
reask_prompt: Optional[str] = None,
reask_instructions: Optional[str] = None,
messages: Optional[List[Dict]] = None,
reask_messages: Optional[List[Dict]] = None,
num_reasks: Optional[int] = None,
tracer: Optional[Tracer] = None,
name: Optional[str] = None,
@@ -137,10 +129,8 @@
guard = super().from_string(
validators,
string_description=string_description,
prompt=prompt,
instructions=instructions,
reask_prompt=reask_prompt,
reask_instructions=reask_instructions,
messages=messages,
reask_messages=reask_messages,
num_reasks=num_reasks,
tracer=tracer,
name=name,
@@ -178,9 +168,7 @@ async def _execute(
llm_output: Optional[str] = None,
prompt_params: Optional[Dict] = None,
num_reasks: Optional[int] = None,
prompt: Optional[str] = None,
instructions: Optional[str] = None,
msg_history: Optional[List[Dict]] = None,
messages: Optional[List[Dict]] = None,
metadata: Optional[Dict],
full_schema_reask: Optional[bool] = None,
**kwargs,
@@ -192,10 +180,8 @@
self._fill_validator_map()
self._fill_validators()
metadata = metadata or {}
if not llm_output and llm_api and not (prompt or msg_history):
raise RuntimeError(
"'prompt' or 'msg_history' must be provided in order to call an LLM!"
)
if not llm_output and llm_api and not (messages):
raise RuntimeError("'messages' must be provided in order to call an LLM!")
# check if validator requirements are fulfilled
missing_keys = verify_metadata_requirements(metadata, self._validators)
if missing_keys:
@@ -210,9 +196,7 @@ async def __exec(
llm_output: Optional[str] = None,
prompt_params: Optional[Dict] = None,
num_reasks: Optional[int] = None,
prompt: Optional[str] = None,
instructions: Optional[str] = None,
msg_history: Optional[List[Dict]] = None,
messages: Optional[List[Dict]] = None,
metadata: Optional[Dict] = None,
full_schema_reask: Optional[bool] = None,
**kwargs,
@@ -245,14 +229,6 @@ async def __exec(
("guard_id", self.id),
("user_id", self._user_id),
("llm_api", llm_api_str),
(
"custom_reask_prompt",
self._exec_opts.reask_prompt is not None,
),
(
"custom_reask_instructions",
self._exec_opts.reask_instructions is not None,
),
(
"custom_reask_messages",
self._exec_opts.reask_messages is not None,
@@ -273,14 +249,10 @@
"This should never happen."
)

input_prompt = prompt or self._exec_opts.prompt
input_instructions = instructions or self._exec_opts.instructions
input_messages = messages or self._exec_opts.messages
call_inputs = CallInputs(
llm_api=llm_api,
prompt=input_prompt,
instructions=input_instructions,
msg_history=msg_history,
prompt_params=prompt_params,
messages=input_messages,
num_reasks=self._num_reasks,
metadata=metadata,
full_schema_reask=full_schema_reask,
@@ -298,9 +270,7 @@
prompt_params=prompt_params,
metadata=metadata,
full_schema_reask=full_schema_reask,
prompt=prompt,
instructions=instructions,
msg_history=msg_history,
messages=messages,
*args,
**kwargs,
)
@@ -315,9 +285,7 @@
llm_output=llm_output,
prompt_params=prompt_params,
num_reasks=self._num_reasks,
prompt=prompt,
instructions=instructions,
msg_history=msg_history,
messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
call_log=call_log,
@@ -343,9 +311,7 @@
llm_output=llm_output,
prompt_params=prompt_params,
num_reasks=num_reasks,
prompt=prompt,
instructions=instructions,
msg_history=msg_history,
messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
*args,
@@ -362,9 +328,7 @@ async def _exec(
num_reasks: int = 0, # Should be defined at this point
metadata: Dict, # Should be defined at this point
full_schema_reask: bool = False, # Should be defined at this point
prompt: Optional[str],
instructions: Optional[str],
msg_history: Optional[List[Dict]],
messages: List[Dict] = None,
**kwargs,
) -> Union[
ValidationOutcome[OT],
@@ -377,9 +341,7 @@
llm_api: The LLM API to call asynchronously (e.g. openai.Completion.acreate)
prompt_params: The parameters to pass to the prompt.format() method.
num_reasks: The max times to re-ask the LLM for invalid output.
prompt: The prompt to use for the LLM.
instructions: Instructions for chat models.
msg_history: The message history to pass to the LLM.
messages: List of messages for llm to respond to.
metadata: Metadata to pass to the validators.
full_schema_reask: When reasking, whether to regenerate the full schema
or just the incorrect values.
@@ -396,9 +358,7 @@
output_schema=self.output_schema.to_dict(),
num_reasks=num_reasks,
validation_map=self._validator_map,
prompt=prompt,
instructions=instructions,
msg_history=msg_history,
messages=messages,
api=api,
metadata=metadata,
output=llm_output,
@@ -418,9 +378,7 @@
output_schema=self.output_schema.to_dict(),
num_reasks=num_reasks,
validation_map=self._validator_map,
prompt=prompt,
instructions=instructions,
msg_history=msg_history,
messages=messages,
api=api,
metadata=metadata,
output=llm_output,
@@ -441,9 +399,7 @@ async def __call__(
*args,
prompt_params: Optional[Dict] = None,
num_reasks: Optional[int] = 1,
prompt: Optional[str] = None,
instructions: Optional[str] = None,
msg_history: Optional[List[Dict]] = None,
messages: Optional[List[Dict]] = None,
metadata: Optional[Dict] = None,
full_schema_reask: Optional[bool] = None,
**kwargs,
@@ -460,9 +416,7 @@
(e.g. openai.completions.create or openai.chat.completions.create)
prompt_params: The parameters to pass to the prompt.format() method.
num_reasks: The max times to re-ask the LLM for invalid output.
prompt: The prompt to use for the LLM.
instructions: Instructions for chat models.
msg_history: The message history to pass to the LLM.
messages: The messages to pass to the LLM.
metadata: Metadata to pass to the validators.
full_schema_reask: When reasking, whether to regenerate the full schema
or just the incorrect values.
@@ -473,25 +427,18 @@
The raw text output from the LLM and the validated output.
"""

instructions = instructions or self._exec_opts.instructions
prompt = prompt or self._exec_opts.prompt
msg_history = msg_history or kwargs.pop("messages", None) or []

if prompt is None:
if msg_history is not None and not len(msg_history):
raise RuntimeError(
"You must provide a prompt if msg_history is empty. "
"Alternatively, you can provide a prompt in the Schema constructor."
)
messages = messages or self._exec_opts.messages or []
if messages is not None and not len(messages):
raise RuntimeError(
"You must provide messages to the LLM in order to call it."
)

return await self._execute(
*args,
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
prompt=prompt,
instructions=instructions,
msg_history=msg_history,
messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
**kwargs,
@@ -534,24 +481,16 @@ async def parse(
if llm_api is None
else 1
)
default_prompt = self._exec_opts.prompt if llm_api is not None else None
prompt = kwargs.pop("prompt", default_prompt)

default_instructions = self._exec_opts.instructions if llm_api else None
instructions = kwargs.pop("instructions", default_instructions)

default_msg_history = self._exec_opts.msg_history if llm_api else None
msg_history = kwargs.pop("msg_history", default_msg_history)
default_messages = self._exec_opts.messages if llm_api else None
messages = kwargs.pop("messages", default_messages)

return await self._execute( # type: ignore
*args,
llm_output=llm_output,
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=final_num_reasks,
prompt=prompt,
instructions=instructions,
msg_history=msg_history,
messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
**kwargs,
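
Taken together, the async_guard.py changes collapse prompt, instructions, and msg_history into a single messages argument across from_pydantic, from_string, __call__, and parse. A hedged end-to-end sketch of the new calling convention — the import path, validator list, LLM callable, template variables, and outcome attribute are assumptions, not shown in this diff:

# Hedged sketch, not part of the diff.
import asyncio
from guardrails import AsyncGuard  # assumed top-level import

async def main():
    guard = AsyncGuard.from_string(
        validators=[],  # placeholder: real validators would go here
        # reask_prompt / reask_instructions are gone; reask_messages replaces them
        reask_messages=[
            {"role": "system", "content": "Correct the previous answer."},
            {"role": "user", "content": "${output_schema}"},  # template var assumed
        ],
    )
    outcome = await guard(
        llm_api=None,  # placeholder: e.g. openai.chat.completions.create per the docstring
        messages=[     # replaces prompt=..., instructions=..., msg_history=...
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize: ${doc}"},
        ],
        prompt_params={"doc": "…"},
        num_reasks=1,
    )
    print(outcome.validated_output)  # attribute assumed

# asyncio.run(main())  # left commented out: requires a real LLM callable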