diff --git a/guardrails/actions/reask.py b/guardrails/actions/reask.py
index 2ca1f711b..629cab514 100644
--- a/guardrails/actions/reask.py
+++ b/guardrails/actions/reask.py
@@ -247,7 +247,7 @@ def get_reask_setup_for_string(
validation_response: Optional[Union[str, List, Dict, ReAsk]] = None,
prompt_params: Optional[Dict[str, Any]] = None,
exec_options: Optional[GuardExecutionOptions] = None,
-) -> Tuple[Dict[str, Any], Prompt, Instructions]:
+) -> Tuple[Dict[str, Any], Messages]:
prompt_params = prompt_params or {}
exec_options = exec_options or GuardExecutionOptions()
@@ -300,7 +300,10 @@ def get_reask_setup_for_string(
messages = Messages(exec_options.reask_messages)
if messages is None:
messages = Messages(
- [{"role": "system", "content": "You are a helpful assistant."}]
+ [
+ {"role": "system", "content": instructions},
+ {"role": "user", "content": prompt},
+ ]
)
messages = messages.format(
@@ -309,7 +312,17 @@ def get_reask_setup_for_string(
**prompt_params,
)
- return output_schema, prompt, instructions
+ return output_schema, messages
+
+
+def get_original_messages(exec_options: Optional[GuardExecutionOptions] = None) -> str:
+ exec_options = exec_options or GuardExecutionOptions()
+ original_messages = exec_options.messages or []
+ messages_prompt = next(
+ (h.get("content") for h in original_messages if isinstance(h, dict)),
+ "",
+ )
+ return messages_prompt
def get_original_prompt(exec_options: Optional[GuardExecutionOptions] = None) -> str:
@@ -338,7 +351,7 @@ def get_reask_setup_for_json(
use_full_schema: Optional[bool] = False,
prompt_params: Optional[Dict[str, Any]] = None,
exec_options: Optional[GuardExecutionOptions] = None,
-) -> Tuple[Dict[str, Any], Prompt, Instructions]:
+) -> Tuple[Dict[str, Any], Messages]:
reask_schema = output_schema
is_skeleton_reask = not any(isinstance(reask, FieldReAsk) for reask in reasks)
is_nonparseable_reask = any(
@@ -347,12 +360,10 @@ def get_reask_setup_for_json(
error_messages = {}
prompt_params = prompt_params or {}
exec_options = exec_options or GuardExecutionOptions()
- original_prompt = get_original_prompt(exec_options)
- use_xml = prompt_uses_xml(original_prompt)
+ original_messages = get_original_messages(exec_options)
+ use_xml = prompt_uses_xml(original_messages)
reask_prompt_template = None
- if exec_options.reask_prompt:
- reask_prompt_template = Prompt(exec_options.reask_prompt)
if is_nonparseable_reask:
if reask_prompt_template is None:
@@ -461,31 +472,26 @@ def reask_decoder(obj: ReAsk):
**prompt_params,
)
- instructions = None
- if exec_options.reask_instructions:
- instructions = Instructions(exec_options.reask_instructions)
- else:
- instructions_const = (
- constants["high_level_xml_instructions"]
- if use_xml
- else constants["high_level_json_instructions"]
- )
- instructions = Instructions(instructions_const)
+ instructions_const = (
+ constants["high_level_xml_instructions"]
+ if use_xml
+ else constants["high_level_json_instructions"]
+ )
+ instructions = Instructions(instructions_const)
instructions = instructions.format(**prompt_params)
- # TODO: enable this in 0.6.0
- # messages = None
- # if exec_options.reask_messages:
- # messages = Messages(exec_options.reask_messages)
- # else:
- # messages = Messages(
- # [
- # {"role": "system", "content": instructions},
- # {"role": "user", "content": prompt},
- # ]
- # )
+ messages = None
+ if exec_options.reask_messages:
+ messages = Messages(exec_options.reask_messages)
+ else:
+ messages = Messages(
+ [
+ {"role": "system", "content": instructions},
+ {"role": "user", "content": prompt},
+ ]
+ )
- return reask_schema, prompt, instructions
+ return reask_schema, messages
def get_reask_setup(
@@ -499,7 +505,7 @@ def get_reask_setup(
use_full_schema: Optional[bool] = False,
prompt_params: Optional[Dict[str, Any]] = None,
exec_options: Optional[GuardExecutionOptions] = None,
-) -> Tuple[Dict[str, Any], Prompt, Instructions]:
+) -> Tuple[Dict[str, Any], Messages]:
prompt_params = prompt_params or {}
exec_options = exec_options or GuardExecutionOptions()
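
Note on the new return shape: both reask setup helpers now hand back `(output_schema, messages)` instead of `(output_schema, prompt, instructions)`. A minimal sketch of the default reask message list the new path builds when `exec_options.reask_messages` is not set, assuming `Messages` stores the list it is given (as the `__iter__` added in `guardrails/prompt/messages.py` implies); the instruction and prompt strings are illustrative placeholders, not the real templates:

```python
from guardrails.prompt.messages import Messages

# Illustrative stand-ins for the compiled reask instructions and prompt.
instructions = "You are a helpful assistant."
reask_prompt = "Your previous answer failed validation; please correct it."

# Default shape used when no custom reask_messages are configured.
messages = Messages(
    [
        {"role": "system", "content": instructions},
        {"role": "user", "content": reask_prompt},
    ]
)

# The runner later interpolates prompt params via messages.format(**prompt_params).
for msg in messages:
    print(msg["role"], "->", msg["content"])
```
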
diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py
index f8f12d085..1c727d8b4 100644
--- a/guardrails/async_guard.py
+++ b/guardrails/async_guard.py
@@ -92,11 +92,8 @@ def from_pydantic(
cls,
output_class: ModelOrListOfModels,
*,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
num_reasks: Optional[int] = None,
- reask_prompt: Optional[str] = None,
- reask_instructions: Optional[str] = None,
+ messages: Optional[List[Dict]] = None,
reask_messages: Optional[List[Dict]] = None,
tracer: Optional[Tracer] = None,
name: Optional[str] = None,
@@ -104,11 +101,8 @@ def from_pydantic(
):
guard = super().from_pydantic(
output_class,
- prompt=prompt,
- instructions=instructions,
+ messages=messages,
num_reasks=num_reasks,
- reask_prompt=reask_prompt,
- reask_instructions=reask_instructions,
reask_messages=reask_messages,
tracer=tracer,
name=name,
@@ -125,10 +119,8 @@ def from_string(
validators: Sequence[Validator],
*,
string_description: Optional[str] = None,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- reask_prompt: Optional[str] = None,
- reask_instructions: Optional[str] = None,
+ messages: Optional[List[Dict]] = None,
+ reask_messages: Optional[List[Dict]] = None,
num_reasks: Optional[int] = None,
tracer: Optional[Tracer] = None,
name: Optional[str] = None,
@@ -137,10 +129,8 @@ def from_string(
guard = super().from_string(
validators,
string_description=string_description,
- prompt=prompt,
- instructions=instructions,
- reask_prompt=reask_prompt,
- reask_instructions=reask_instructions,
+ messages=messages,
+ reask_messages=reask_messages,
num_reasks=num_reasks,
tracer=tracer,
name=name,
@@ -178,9 +168,7 @@ async def _execute(
llm_output: Optional[str] = None,
prompt_params: Optional[Dict] = None,
num_reasks: Optional[int] = None,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict]] = None,
metadata: Optional[Dict],
full_schema_reask: Optional[bool] = None,
**kwargs,
@@ -192,10 +180,8 @@ async def _execute(
self._fill_validator_map()
self._fill_validators()
metadata = metadata or {}
- if not llm_output and llm_api and not (prompt or msg_history):
- raise RuntimeError(
- "'prompt' or 'msg_history' must be provided in order to call an LLM!"
- )
+ if not llm_output and llm_api and not (messages):
+ raise RuntimeError("'messages' must be provided in order to call an LLM!")
# check if validator requirements are fulfilled
missing_keys = verify_metadata_requirements(metadata, self._validators)
if missing_keys:
@@ -210,9 +196,7 @@ async def __exec(
llm_output: Optional[str] = None,
prompt_params: Optional[Dict] = None,
num_reasks: Optional[int] = None,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict]] = None,
metadata: Optional[Dict] = None,
full_schema_reask: Optional[bool] = None,
**kwargs,
@@ -245,14 +229,6 @@ async def __exec(
("guard_id", self.id),
("user_id", self._user_id),
("llm_api", llm_api_str),
- (
- "custom_reask_prompt",
- self._exec_opts.reask_prompt is not None,
- ),
- (
- "custom_reask_instructions",
- self._exec_opts.reask_instructions is not None,
- ),
(
"custom_reask_messages",
self._exec_opts.reask_messages is not None,
@@ -273,14 +249,10 @@ async def __exec(
"This should never happen."
)
- input_prompt = prompt or self._exec_opts.prompt
- input_instructions = instructions or self._exec_opts.instructions
+ input_messages = messages or self._exec_opts.messages
call_inputs = CallInputs(
llm_api=llm_api,
- prompt=input_prompt,
- instructions=input_instructions,
- msg_history=msg_history,
- prompt_params=prompt_params,
+ messages=input_messages,
num_reasks=self._num_reasks,
metadata=metadata,
full_schema_reask=full_schema_reask,
@@ -298,9 +270,7 @@ async def __exec(
prompt_params=prompt_params,
metadata=metadata,
full_schema_reask=full_schema_reask,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
*args,
**kwargs,
)
@@ -315,9 +285,7 @@ async def __exec(
llm_output=llm_output,
prompt_params=prompt_params,
num_reasks=self._num_reasks,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
call_log=call_log,
@@ -343,9 +311,7 @@ async def __exec(
llm_output=llm_output,
prompt_params=prompt_params,
num_reasks=num_reasks,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
*args,
@@ -362,9 +328,7 @@ async def _exec(
num_reasks: int = 0, # Should be defined at this point
metadata: Dict, # Should be defined at this point
full_schema_reask: bool = False, # Should be defined at this point
- prompt: Optional[str],
- instructions: Optional[str],
- msg_history: Optional[List[Dict]],
+        messages: Optional[List[Dict]] = None,
**kwargs,
) -> Union[
ValidationOutcome[OT],
@@ -377,9 +341,7 @@ async def _exec(
llm_api: The LLM API to call asynchronously (e.g. openai.Completion.acreate)
prompt_params: The parameters to pass to the prompt.format() method.
num_reasks: The max times to re-ask the LLM for invalid output.
- prompt: The prompt to use for the LLM.
- instructions: Instructions for chat models.
- msg_history: The message history to pass to the LLM.
+            messages: The messages to pass to the LLM.
metadata: Metadata to pass to the validators.
full_schema_reask: When reasking, whether to regenerate the full schema
or just the incorrect values.
@@ -396,9 +358,7 @@ async def _exec(
output_schema=self.output_schema.to_dict(),
num_reasks=num_reasks,
validation_map=self._validator_map,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
api=api,
metadata=metadata,
output=llm_output,
@@ -418,9 +378,7 @@ async def _exec(
output_schema=self.output_schema.to_dict(),
num_reasks=num_reasks,
validation_map=self._validator_map,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
api=api,
metadata=metadata,
output=llm_output,
@@ -441,9 +399,7 @@ async def __call__(
*args,
prompt_params: Optional[Dict] = None,
num_reasks: Optional[int] = 1,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict]] = None,
metadata: Optional[Dict] = None,
full_schema_reask: Optional[bool] = None,
**kwargs,
@@ -460,9 +416,7 @@ async def __call__(
(e.g. openai.completions.create or openai.chat.completions.create)
prompt_params: The parameters to pass to the prompt.format() method.
num_reasks: The max times to re-ask the LLM for invalid output.
- prompt: The prompt to use for the LLM.
- instructions: Instructions for chat models.
- msg_history: The message history to pass to the LLM.
+ messages: The messages to pass to the LLM.
metadata: Metadata to pass to the validators.
full_schema_reask: When reasking, whether to regenerate the full schema
or just the incorrect values.
@@ -473,25 +427,18 @@ async def __call__(
The raw text output from the LLM and the validated output.
"""
- instructions = instructions or self._exec_opts.instructions
- prompt = prompt or self._exec_opts.prompt
- msg_history = msg_history or kwargs.pop("messages", None) or []
-
- if prompt is None:
- if msg_history is not None and not len(msg_history):
- raise RuntimeError(
- "You must provide a prompt if msg_history is empty. "
- "Alternatively, you can provide a prompt in the Schema constructor."
- )
+ messages = messages or self._exec_opts.messages or []
+ if messages is not None and not len(messages):
+ raise RuntimeError(
+ "You must provide messages to the LLM in order to call it."
+ )
return await self._execute(
*args,
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
**kwargs,
@@ -534,14 +481,8 @@ async def parse(
if llm_api is None
else 1
)
- default_prompt = self._exec_opts.prompt if llm_api is not None else None
- prompt = kwargs.pop("prompt", default_prompt)
-
- default_instructions = self._exec_opts.instructions if llm_api else None
- instructions = kwargs.pop("instructions", default_instructions)
-
- default_msg_history = self._exec_opts.msg_history if llm_api else None
- msg_history = kwargs.pop("msg_history", default_msg_history)
+ default_messages = self._exec_opts.messages if llm_api else None
+ messages = kwargs.pop("messages", default_messages)
return await self._execute( # type: ignore
*args,
@@ -549,9 +490,7 @@ async def parse(
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=final_num_reasks,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
**kwargs,
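
The async surface now takes `messages` (and `reask_messages`) in place of `prompt`, `instructions`, and `msg_history`. A hedged usage sketch, assuming the LiteLLM-style `model` argument referenced by the deprecation notices in this change set and a configured provider API key; the model id and message text are illustrative:

```python
import asyncio

from guardrails import AsyncGuard

async def main() -> None:
    guard = AsyncGuard()

    # messages replaces the old prompt/instructions/msg_history trio.
    outcome = await guard(
        model="gpt-4o-mini",  # assumed LiteLLM-routable model id
        messages=[
            {"role": "system", "content": "You are a terse assistant."},
            {"role": "user", "content": "Name one prime number."},
        ],
        num_reasks=1,
    )
    print(outcome.validated_output)

asyncio.run(main())
```
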
diff --git a/guardrails/classes/execution/guard_execution_options.py b/guardrails/classes/execution/guard_execution_options.py
index 1230cf7d8..cabdddbcf 100644
--- a/guardrails/classes/execution/guard_execution_options.py
+++ b/guardrails/classes/execution/guard_execution_options.py
@@ -4,11 +4,6 @@
@dataclass
class GuardExecutionOptions:
- prompt: Optional[str] = None
- instructions: Optional[str] = None
- msg_history: Optional[List[Dict]] = None
messages: Optional[List[Dict]] = None
- reask_prompt: Optional[str] = None
- reask_instructions: Optional[str] = None
reask_messages: Optional[List[Dict]] = None
num_reasks: Optional[int] = None
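
For reference, a short sketch of the slimmed-down options object; only message-centric fields remain:

```python
from guardrails.classes.execution.guard_execution_options import GuardExecutionOptions

# prompt/instructions/msg_history and their reask_* counterparts are gone.
opts = GuardExecutionOptions(
    messages=[{"role": "user", "content": "Summarize: ${document}"}],
    reask_messages=[{"role": "user", "content": "That summary failed validation; try again."}],
    num_reasks=2,
)
print(opts)
```
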
diff --git a/guardrails/classes/history/call.py b/guardrails/classes/history/call.py
index cf79a6e2f..3eacc46d4 100644
--- a/guardrails/classes/history/call.py
+++ b/guardrails/classes/history/call.py
@@ -15,7 +15,6 @@
from guardrails.classes.generic.arbitrary_model import ArbitraryModel
from guardrails.classes.validation.validation_result import ValidationResult
from guardrails.constants import error_status, fail_status, not_run_status, pass_status
-from guardrails.prompt.instructions import Instructions
from guardrails.prompt.prompt import Prompt
from guardrails.prompt.messages import Messages
from guardrails.classes.validation.validator_logs import ValidatorLogs
@@ -116,22 +115,22 @@ def reask_prompts(self) -> Stack[Optional[str]]:
return Stack()
@property
- def instructions(self) -> Optional[str]:
- """The instructions as provided by the user when initializing or
- calling the Guard."""
- return self.inputs.instructions
+ def messages(self) -> Optional[List[Dict[str, Any]]]:
+ """The messages as provided by the user when initializing or calling the
+ Guard."""
+ return self.inputs.messages
@property
- def compiled_instructions(self) -> Optional[str]:
- """The initial compiled instructions that were passed to the LLM on the
+ def compiled_messages(self) -> Optional[List[Dict[str, Any]]]:
+ """The initial compiled messages that were passed to the LLM on the
first call."""
if self.iterations.empty():
return None
- initial_inputs = self.iterations.first.inputs # type: ignore
- instructions: Instructions = initial_inputs.instructions # type: ignore
+ initial_inputs = self.iterations.first.inputs
+ messages = initial_inputs.messages
prompt_params = initial_inputs.prompt_params or {}
- if instructions is not None:
- return instructions.format(**prompt_params).source
+ if messages is not None:
+ return messages.format(**prompt_params).source
@property
def reask_messages(self) -> Stack[Messages]:
@@ -145,27 +144,7 @@ def reask_messages(self) -> Stack[Messages]:
reasks.remove(initial_messages) # type: ignore
return Stack(
*[
- r.inputs.messages if r.inputs.messages is not None else None
- for r in reasks
- ]
- )
-
- return Stack()
-
- @property
- def reask_instructions(self) -> Stack[str]:
- """The compiled instructions used during reasks.
-
- Does not include the initial instructions.
- """
- if self.iterations.length > 0:
- reasks = self.iterations.copy()
- reasks.remove(reasks.first) # type: ignore
- return Stack(
- *[
- r.inputs.instructions.source
- if r.inputs.instructions is not None
- else None
+ r.inputs.messages.source if r.inputs.messages is not None else None
for r in reasks
]
)
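
The history API follows suit: `instructions`, `compiled_instructions`, and `reask_instructions` give way to `messages`, `compiled_messages`, and `reask_messages`. A hedged sketch of reading them after a run; the property names come from this diff, and the guard would need to have executed at least once for `history.last` to be populated:

```python
from guardrails import Guard

guard = Guard()
# ... guard(...) would normally be invoked here with messages and an LLM ...

last_call = guard.history.last
if last_call is not None:
    print(last_call.messages)           # messages as provided by the user
    print(last_call.compiled_messages)  # messages sent on the first LLM call
    print(last_call.reask_messages)     # compiled messages used during reasks
```
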
diff --git a/guardrails/classes/history/call_inputs.py b/guardrails/classes/history/call_inputs.py
index 4aa2363f8..5f22a817c 100644
--- a/guardrails/classes/history/call_inputs.py
+++ b/guardrails/classes/history/call_inputs.py
@@ -28,11 +28,8 @@ class CallInputs(Inputs, ICallInputs, ArbitraryModel):
"during Guard.__call__ or Guard.parse.",
default=None,
)
- prompt: Optional[str] = Field(
- description="The prompt string as provided by the user.", default=None
- )
- instructions: Optional[str] = Field(
- description="The instructions string as provided by the user.", default=None
+ messages: Optional[List[Dict[str, Any]]] = Field(
+ description="The messages as provided by the user.", default=None
)
args: List[Any] = Field(
description="Additional arguments for the LLM as provided by the user.",
diff --git a/guardrails/classes/history/inputs.py b/guardrails/classes/history/inputs.py
index d44d63628..c8b1e05d3 100644
--- a/guardrails/classes/history/inputs.py
+++ b/guardrails/classes/history/inputs.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional
from pydantic import Field
@@ -18,10 +18,7 @@ class Inputs(IInputs, ArbitraryModel):
for calling the LLM.
llm_output (Optional[str]): The string output from an
external LLM call provided by the user via Guard.parse.
- instructions (Optional[Instructions]): The constructed
- Instructions class for chat model calls.
- prompt (Optional[Prompt]): The constructed Prompt class.
- msg_history (Optional[List[Dict]]): The message history
+        messages (Optional[Messages]): The messages
provided by the user for chat model calls.
prompt_params (Optional[Dict]): The parameters provided
by the user that will be formatted into the final LLM prompt.
@@ -42,19 +39,8 @@ class Inputs(IInputs, ArbitraryModel):
"provided by the user via Guard.parse.",
default=None,
)
- instructions: Optional[Instructions] = Field(
- description="The constructed Instructions class for chat model calls.",
- default=None,
- )
- prompt: Optional[Prompt] = Field(
- description="The constructed Prompt class.", default=None
- )
- msg_history: Optional[List[Dict]] = Field(
- description="The message history provided by the user for chat model calls.",
- default=None,
- )
- messages: Optional[List[Messages]] = Field(
- description="The message history provided by the user for chat model calls.",
+ messages: Optional[Messages] = Field(
+ description="The messages provided by the user for chat model calls.",
default=None,
)
prompt_params: Optional[Dict] = Field(
diff --git a/guardrails/classes/history/iteration.py b/guardrails/classes/history/iteration.py
index 45c60baab..8760d2e49 100644
--- a/guardrails/classes/history/iteration.py
+++ b/guardrails/classes/history/iteration.py
@@ -188,63 +188,36 @@ def status(self) -> str:
@property
def rich_group(self) -> Group:
- def create_msg_history_table(
- msg_history: Optional[List[Dict[str, Prompt]]],
+ def create_messages_table(
+            messages: Optional[List[Dict[str, Union[str, Prompt]]]],
) -> Union[str, Table]:
- if msg_history is None:
- return "No message history."
+ if messages is None:
+ return "No messages."
table = Table(show_lines=True)
table.add_column("Role", justify="right", no_wrap=True)
table.add_column("Content")
- for msg in msg_history:
- table.add_row(str(msg["role"]), msg["content"].source)
+ for msg in messages:
+ if isinstance(msg["content"], str):
+ table.add_row(str(msg["role"]), msg["content"])
+ else:
+ table.add_row(str(msg["role"]), msg["content"].source)
return table
- table = create_msg_history_table(self.inputs.msg_history)
-
- if self.inputs.instructions is not None:
- return Group(
- Panel(
- self.inputs.prompt.source if self.inputs.prompt else "No prompt",
- title="Prompt",
- style="on #F0F8FF",
- ),
- Panel(
- self.inputs.instructions.source,
- title="Instructions",
- style="on #FFF0F2",
- ),
- Panel(table, title="Message History", style="on #E7DFEB"),
- Panel(
- self.raw_output or "", title="Raw LLM Output", style="on #F5F5DC"
- ),
- Panel(
- pretty_repr(self.validation_response),
- title="Validated Output",
- style="on #F0FFF0",
- ),
- )
- else:
- return Group(
- Panel(
- self.inputs.prompt.source if self.inputs.prompt else "No prompt",
- title="Prompt",
- style="on #F0F8FF",
- ),
- Panel(table, title="Message History", style="on #E7DFEB"),
- Panel(
- self.raw_output or "", title="Raw LLM Output", style="on #F5F5DC"
- ),
- Panel(
- self.validation_response
- if isinstance(self.validation_response, str)
- else pretty_repr(self.validation_response),
- title="Validated Output",
- style="on #F0FFF0",
- ),
- )
+ table = create_messages_table(self.inputs.messages)
+
+ return Group(
+ Panel(table, title="Messages", style="on #E7DFEB"),
+ Panel(self.raw_output or "", title="Raw LLM Output", style="on #F5F5DC"),
+ Panel(
+ self.validation_response
+ if isinstance(self.validation_response, str)
+ else pretty_repr(self.validation_response),
+ title="Validated Output",
+ style="on #F0FFF0",
+ ),
+ )
def __str__(self) -> str:
return pretty_repr(self)
diff --git a/guardrails/guard.py b/guardrails/guard.py
index fcd40f022..84a25bcd4 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -310,26 +310,17 @@ def _fill_exec_opts(
self,
*,
num_reasks: Optional[int] = None,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
- reask_prompt: Optional[str] = None,
- reask_instructions: Optional[str] = None,
+ messages: Optional[List[Dict]] = None,
+ reask_messages: Optional[List[Dict]] = None,
**kwargs, # noqa
):
"""Backfill execution options from kwargs."""
if num_reasks is not None:
self._exec_opts.num_reasks = num_reasks
- if prompt is not None:
- self._exec_opts.prompt = prompt
- if instructions is not None:
- self._exec_opts.instructions = instructions
- if msg_history is not None:
- self._exec_opts.msg_history = msg_history
- if reask_prompt is not None:
- self._exec_opts.reask_prompt = reask_prompt
- if reask_instructions is not None:
- self._exec_opts.reask_instructions = reask_instructions
+ if messages is not None:
+ self._exec_opts.messages = messages
+ if reask_messages is not None:
+ self._exec_opts.reask_messages = reask_messages
@classmethod
def _from_rail_schema(
@@ -402,6 +393,7 @@ def from_rail(
cls._set_tracer(cls, tracer) # type: ignore
schema = rail_file_to_schema(rail_file)
+ print("==== rail schema", schema)
return cls._from_rail_schema(
schema,
rail=rail_file,
@@ -451,6 +443,7 @@ def from_rail_string(
cls._set_tracer(cls, tracer) # type: ignore
schema = rail_string_to_schema(rail_string)
+
return cls._from_rail_schema(
schema,
rail=rail_string,
@@ -465,11 +458,7 @@ def from_pydantic(
cls,
output_class: ModelOrListOfModels,
*,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
num_reasks: Optional[int] = None,
- reask_prompt: Optional[str] = None,
- reask_instructions: Optional[str] = None,
reask_messages: Optional[List[Dict]] = None,
messages: Optional[List[Dict]] = None,
tracer: Optional[Tracer] = None,
@@ -483,10 +472,7 @@ def from_pydantic(
Args:
output_class: (Union[Type[BaseModel], List[Type[BaseModel]]]): The pydantic model that describes
the desired structure of the output.
- prompt (str, optional): The prompt used to generate the string. Defaults to None.
- instructions (str, optional): Instructions for chat models. Defaults to None.
- reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None.
- reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None.
+            messages (List[Dict], optional): A list of messages to send to the LLM. Defaults to None.
reask_messages (List[Dict], optional): A list of messages to use during reasks. Defaults to None.
num_reasks (int, optional): The max times to re-ask the LLM if validation fails. Deprecated
tracer (Tracer, optional): An OpenTelemetry tracer to use for metrics and traces. Defaults to None.
@@ -506,31 +492,14 @@ def from_pydantic(
DeprecationWarning,
)
- if reask_instructions:
- warnings.warn(
- "reask_instructions is deprecated and will be removed in 0.6.x!"
- "Please be prepared to set reask_messages instead.",
- DeprecationWarning,
- )
- if reask_prompt:
- warnings.warn(
- "reask_prompt is deprecated and will be removed in 0.6.x!"
- "Please be prepared to set reask_messages instead.",
- DeprecationWarning,
- )
-
# We have to set the tracer in the ContextStore before the Rail,
# and therefore the Validators, are initialized
cls._set_tracer(cls, tracer) # type: ignore
schema = pydantic_model_to_schema(output_class)
exec_opts = GuardExecutionOptions(
- prompt=prompt,
- instructions=instructions,
- reask_prompt=reask_prompt,
- reask_instructions=reask_instructions,
- reask_messages=reask_messages,
messages=messages,
+ reask_messages=reask_messages,
)
guard = cls(
name=name,
@@ -566,10 +535,6 @@ def from_string(
validators: Sequence[Validator],
*,
string_description: Optional[str] = None,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- reask_prompt: Optional[str] = None,
- reask_instructions: Optional[str] = None,
reask_messages: Optional[List[Dict]] = None,
messages: Optional[List[Dict]] = None,
num_reasks: Optional[int] = None,
@@ -582,28 +547,13 @@ def from_string(
Args:
validators: (List[Validator]): The list of validators to apply to the string output.
string_description (str, optional): A description for the string to be generated. Defaults to None.
- prompt (str, optional): The prompt used to generate the string. Defaults to None.
- instructions (str, optional): Instructions for chat models. Defaults to None.
- reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None.
- reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None.
+            messages (List[Dict], optional): A list of messages to send to the LLM. Defaults to None.
reask_messages (List[Dict], optional): A list of messages to use during reasks. Defaults to None.
num_reasks (int, optional): The max times to re-ask the LLM if validation fails. Deprecated
tracer (Tracer, optional): An OpenTelemetry tracer to use for metrics and traces. Defaults to None.
name (str, optional): A unique name for this Guard. Defaults to `gr-` + the object id.
description (str, optional): A description for this Guard. Defaults to None.
""" # noqa
- if reask_instructions:
- warnings.warn(
- "reask_instructions is deprecated and will be removed in 0.6.x!"
- "Please be prepared to set reask_messages instead.",
- DeprecationWarning,
- )
- if reask_prompt:
- warnings.warn(
- "reask_prompt is deprecated and will be removed in 0.6.x!"
- "Please be prepared to set reask_messages instead.",
- DeprecationWarning,
- )
if num_reasks:
warnings.warn(
@@ -623,10 +573,6 @@ def from_string(
list(validators), type=SimpleTypes.STRING, description=string_description
)
exec_opts = GuardExecutionOptions(
- prompt=prompt,
- instructions=instructions,
- reask_prompt=reask_prompt,
- reask_instructions=reask_instructions,
reask_messages=reask_messages,
messages=messages,
)
@@ -653,11 +599,8 @@ def _execute(
llm_output: Optional[str] = None,
prompt_params: Optional[Dict] = None,
num_reasks: Optional[int] = None,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
- reask_prompt: Optional[str] = None,
- reask_instructions: Optional[str] = None,
+ messages: Optional[List[Dict]] = None,
+ reask_messages: Optional[List[Dict]] = None,
metadata: Optional[Dict],
full_schema_reask: Optional[bool] = None,
**kwargs,
@@ -666,17 +609,13 @@ def _execute(
self._fill_validators()
self._fill_exec_opts(
num_reasks=num_reasks,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
- reask_prompt=reask_prompt,
- reask_instructions=reask_instructions,
+ messages=messages,
+ reask_messages=reask_messages,
)
metadata = metadata or {}
- if not llm_output and llm_api and not (prompt or msg_history):
- raise RuntimeError(
- "'prompt' or 'msg_history' must be provided in order to call an LLM!"
- )
+ print("==== _execute messages", messages)
+ if not (messages) and llm_api:
+ raise RuntimeError("'messages' must be provided in order to call an LLM!")
# check if validator requirements are fulfilled
missing_keys = verify_metadata_requirements(metadata, self._validators)
@@ -692,9 +631,7 @@ def __exec(
llm_output: Optional[str] = None,
prompt_params: Optional[Dict] = None,
num_reasks: Optional[int] = None,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict]] = None,
metadata: Optional[Dict] = None,
full_schema_reask: Optional[bool] = None,
**kwargs,
@@ -723,14 +660,6 @@ def __exec(
("guard_id", self.id),
("user_id", self._user_id),
("llm_api", llm_api_str if llm_api_str else "None"),
- (
- "custom_reask_prompt",
- self._exec_opts.reask_prompt is not None,
- ),
- (
- "custom_reask_instructions",
- self._exec_opts.reask_instructions is not None,
- ),
(
"custom_reask_messages",
self._exec_opts.reask_messages is not None,
@@ -751,13 +680,10 @@ def __exec(
"This should never happen."
)
- input_prompt = prompt or self._exec_opts.prompt
- input_instructions = instructions or self._exec_opts.instructions
+ input_messages = messages or self._exec_opts.messages
call_inputs = CallInputs(
llm_api=llm_api,
- prompt=input_prompt,
- instructions=input_instructions,
- msg_history=msg_history,
+ messages=input_messages,
prompt_params=prompt_params,
num_reasks=self._num_reasks,
metadata=metadata,
@@ -776,9 +702,7 @@ def __exec(
prompt_params=prompt_params,
metadata=metadata,
full_schema_reask=full_schema_reask,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
*args,
**kwargs,
)
@@ -787,14 +711,13 @@ def __exec(
set_scope(str(object_id(call_log)))
self.history.push(call_log)
# Otherwise, call the LLM synchronously
+ print("====executing messages", messages)
return self._exec(
llm_api=llm_api,
llm_output=llm_output,
prompt_params=prompt_params,
num_reasks=self._num_reasks,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
call_log=call_log,
@@ -817,9 +740,7 @@ def __exec(
llm_output=llm_output,
prompt_params=prompt_params,
num_reasks=num_reasks,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
*args,
@@ -836,9 +757,7 @@ def _exec(
num_reasks: int = 0, # Should be defined at this point
metadata: Dict, # Should be defined at this point
full_schema_reask: bool = False, # Should be defined at this point
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict]] = None,
**kwargs,
) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]:
api = None
@@ -858,9 +777,7 @@ def _exec(
output_schema=self.output_schema.to_dict(),
num_reasks=num_reasks,
validation_map=self._validator_map,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
api=api,
metadata=metadata,
output=llm_output,
@@ -877,9 +794,7 @@ def _exec(
output_schema=self.output_schema.to_dict(),
num_reasks=num_reasks,
validation_map=self._validator_map,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
api=api,
metadata=metadata,
output=llm_output,
@@ -897,9 +812,6 @@ def __call__(
*args,
prompt_params: Optional[Dict] = None,
num_reasks: Optional[int] = 1,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
metadata: Optional[Dict] = None,
full_schema_reask: Optional[bool] = None,
**kwargs,
@@ -911,9 +823,6 @@ def __call__(
(e.g. openai.completions.create or openai.Completion.acreate)
prompt_params: The parameters to pass to the prompt.format() method.
num_reasks: The max times to re-ask the LLM for invalid output.
- prompt: The prompt to use for the LLM.
- instructions: Instructions for chat models.
- msg_history: The message history to pass to the LLM.
metadata: Metadata to pass to the validators.
full_schema_reask: When reasking, whether to regenerate the full schema
or just the incorrect values.
@@ -923,24 +832,20 @@ def __call__(
Returns:
ValidationOutcome
"""
- instructions = instructions or self._exec_opts.instructions
- prompt = prompt or self._exec_opts.prompt
- msg_history = msg_history or kwargs.get("messages", None) or []
- if prompt is None:
- if msg_history is not None and not len(msg_history):
- raise RuntimeError(
- "You must provide a prompt if msg_history is empty. "
- "Alternatively, you can provide a prompt in the Schema constructor."
- )
-
+ messages = kwargs.pop("messages", None) or self._exec_opts.messages or []
+ print("==== call kwargs", kwargs)
+ print("==== call messages", messages)
+ if messages is not None and not len(messages):
+ raise RuntimeError(
+ "You must provide messages "
+ "Alternatively, you can provide a messages in the Schema constructor."
+ )
return self._execute(
*args,
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
**kwargs,
@@ -981,14 +886,8 @@ def parse(
if llm_api is None
else 1
)
- default_prompt = self._exec_opts.prompt if llm_api else None
- prompt = kwargs.pop("prompt", default_prompt)
-
- default_instructions = self._exec_opts.instructions if llm_api else None
- instructions = kwargs.pop("instructions", default_instructions)
-
- default_msg_history = self._exec_opts.msg_history if llm_api else None
- msg_history = kwargs.pop("msg_history", default_msg_history)
+ default_messages = self._exec_opts.messages if llm_api else None
+ messages = kwargs.pop("messages", default_messages)
return self._execute( # type: ignore # streams are supported for parse
*args,
@@ -996,9 +895,7 @@ def parse(
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=final_num_reasks,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
metadata=metadata,
full_schema_reask=full_schema_reask,
**kwargs,
@@ -1020,14 +917,12 @@ def error_spans_in_output(self) -> List[ErrorSpan]:
def __add_validator(self, validator: Validator, on: str = "output"):
if on not in [
"output",
- "prompt",
- "instructions",
- "msg_history",
+ "messages",
] and not on.startswith("$"):
warnings.warn(
f"Unusual 'on' value: {on}!"
"This value is typically one of "
- "'output', 'prompt', 'instructions', 'msg_history') "
+ "'output', 'messages') "
"or a JSON path starting with '$.'",
UserWarning,
)
@@ -1063,9 +958,7 @@ def use(
) -> "Guard":
"""Use a validator to validate either of the following:
- The output of an LLM request
- - The prompt
- - The instructions
- - The message history
+ - The messages
Args:
validator: The validator to use. Either the class or an instance.
@@ -1248,16 +1141,10 @@ def _call_server(
if llm_api is not None:
payload["llmApi"] = get_llm_api_enum(llm_api, *args, **kwargs)
- if not payload.get("prompt"):
- payload["prompt"] = self._exec_opts.prompt
- if not payload.get("instructions"):
- payload["instructions"] = self._exec_opts.instructions
- if not payload.get("msg_history"):
- payload["msg_history"] = self._exec_opts.msg_history
- if not payload.get("reask_prompt"):
- payload["reask_prompt"] = self._exec_opts.reask_prompt
- if not payload.get("reask_instructions"):
- payload["reask_instructions"] = self._exec_opts.reask_instructions
+ if not payload.get("messages"):
+ payload["messages"] = self._exec_opts.messages
+ if not payload.get("reask_messages"):
+ payload["reask_messages"] = self._exec_opts.reask_messages
should_stream = kwargs.get("stream", False)
if should_stream:
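
Taken together, the synchronous `Guard` now accepts only `messages` and `reask_messages`. A hedged end-to-end sketch under those assumptions; the pydantic model, message text, and JSON payload are illustrative, and `parse` validates pre-generated output without calling an LLM:

```python
from pydantic import BaseModel

from guardrails import Guard

class Pet(BaseModel):
    name: str
    species: str

guard = Guard.from_pydantic(
    Pet,
    messages=[
        {"role": "system", "content": "Respond with JSON only."},
        {"role": "user", "content": "Describe a pet named ${pet_name}."},
    ],
    reask_messages=[
        {"role": "user", "content": "The previous JSON was invalid; please fix it."},
    ],
)

# Validating already-generated output does not require an LLM call.
outcome = guard.parse('{"name": "Rex", "species": "dog"}')
print(outcome.validation_passed)
```
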
diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py
index 2eb3d8f26..c656259ab 100644
--- a/guardrails/llm_providers.py
+++ b/guardrails/llm_providers.py
@@ -309,170 +309,6 @@ def _invoke_llm(
)
-class CohereCallable(PromptCallableBase):
- def _invoke_llm(
- self, prompt: str, client_callable: Any, model: str, *args, **kwargs
- ) -> LLMResponse:
- """To use cohere for guardrails, do ``` client =
- cohere.Client(api_key=...)
-
- raw_llm_response, validated_response, *rest = guard(
- client.generate,
- prompt_params={...},
- model="command-nightly",
- ...
- )
- ```
- """ # noqa
- warnings.warn(
- "This callable is deprecated in favor of passing "
- "no callable and the model argument which utilizes LiteLLM"
- "for example guard(model='command-r', messages=[...], ...)",
- DeprecationWarning,
- )
-
- trace_input_messages = chat_prompt(prompt, kwargs.get("instructions"))
- if "instructions" in kwargs:
- prompt = kwargs.pop("instructions") + "\n\n" + prompt
-
- def is_base_cohere_chat(func):
- try:
- return (
- func.__closure__[1].cell_contents.__func__.__qualname__
- == "BaseCohere.chat"
- )
- except (AttributeError, IndexError):
- return False
-
- # TODO: When cohere totally gets rid of `generate`,
- # remove this cond and the final return
- if is_base_cohere_chat(client_callable):
- trace_operation(
- input_mime_type="application/json",
- input_value={**kwargs, "message": prompt, "args": args, "model": model},
- )
-
- trace_llm_call(
- input_messages=trace_input_messages,
- invocation_parameters={**kwargs, "message": prompt, "model": model},
- )
- cohere_response = client_callable(
- message=prompt, model=model, *args, **kwargs
- )
- trace_operation(
- output_mime_type="application/json", output_value=cohere_response
- )
- trace_llm_call(
- output_messages=[{"role": "assistant", "content": cohere_response.text}]
- )
- return LLMResponse(
- output=cohere_response.text,
- )
-
- trace_operation(
- input_mime_type="application/json",
- input_value={**kwargs, "prompt": prompt, "args": args, "model": model},
- )
-
- trace_llm_call(
- input_messages=trace_input_messages,
- invocation_parameters={**kwargs, "prompt": prompt, "model": model},
- )
- cohere_response = client_callable(prompt=prompt, model=model, *args, **kwargs)
- trace_operation(
- output_mime_type="application/json", output_value=cohere_response
- )
- trace_llm_call(
- output_messages=[{"role": "assistant", "content": cohere_response[0].text}]
- )
- return LLMResponse(
- output=cohere_response[0].text,
- )
-
-
-class AnthropicCallable(PromptCallableBase):
- def _invoke_llm(
- self,
- prompt: str,
- client_callable: Any,
- model: str = "claude-instant-1",
- max_tokens_to_sample: int = 100,
- *args,
- **kwargs,
- ) -> LLMResponse:
- """Wrapper for Anthropic Completions.
-
- To use Anthropic for guardrails, do
- ```
- client = anthropic.Anthropic(api_key=...)
-
- raw_llm_response, validated_response = guard(
- client,
- model="claude-2",
- max_tokens_to_sample=200,
- prompt_params={...},
- ...
- ```
- """
- warnings.warn(
- "This callable is deprecated in favor of passing "
- "no callable and the model argument which utilizes LiteLLM"
- "for example guard(model='claude-3-opus-20240229', messages=[...], ...)",
- DeprecationWarning,
- )
- try:
- import anthropic
- except ImportError:
- raise PromptCallableException(
- "The `anthropic` package is not installed. "
- "Install with `pip install anthropic`"
- )
-
- trace_input_messages = chat_prompt(prompt, kwargs.get("instructions"))
- if "instructions" in kwargs:
- prompt = kwargs.pop("instructions") + "\n\n" + prompt
-
- anthropic_prompt = f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}"
-
- trace_operation(
- input_mime_type="application/json",
- input_value={
- **kwargs,
- "model": model,
- "prompt": anthropic_prompt,
- "max_tokens_to_sample": max_tokens_to_sample,
- "args": args,
- },
- )
-
- trace_llm_call(
- input_messages=trace_input_messages,
- invocation_parameters={
- **kwargs,
- "model": model,
- "prompt": anthropic_prompt,
- "max_tokens_to_sample": max_tokens_to_sample,
- },
- )
-
- anthropic_response = client_callable(
- model=model,
- prompt=anthropic_prompt,
- max_tokens_to_sample=max_tokens_to_sample,
- *args,
- **kwargs,
- )
- trace_operation(
- output_mime_type="application/json", output_value=anthropic_response
- )
- trace_llm_call(
- output_messages=[
- {"role": "assistant", "content": anthropic_response.completion}
- ]
- )
- return LLMResponse(output=anthropic_response.completion)
-
-
class LiteLLMCallable(PromptCallableBase):
def _invoke_llm(
self,
@@ -815,7 +651,6 @@ def get_llm_ask(
) -> Optional[PromptCallableBase]:
if "temperature" not in kwargs:
kwargs.update({"temperature": 0})
-
try:
from litellm import completion
@@ -824,11 +659,6 @@ def get_llm_ask(
except ImportError:
pass
- if llm_api == get_static_openai_create_func():
- return OpenAICallable(*args, **kwargs)
- if llm_api == get_static_openai_chat_create_func():
- return OpenAIChatCallable(*args, **kwargs)
-
try:
import manifest # noqa: F401 # type: ignore
@@ -837,28 +667,6 @@ def get_llm_ask(
except ImportError:
pass
- try:
- import cohere # noqa: F401 # type: ignore
-
- if (
- isinstance(getattr(llm_api, "__self__", None), cohere.Client)
- and getattr(llm_api, "__name__", None) == "generate"
- ) or getattr(llm_api, "__module__", None) == "cohere.client":
- return CohereCallable(*args, client_callable=llm_api, **kwargs)
- except ImportError:
- pass
-
- try:
- import anthropic.resources # noqa: F401 # type: ignore
-
- if isinstance(
- getattr(llm_api, "__self__", None),
- anthropic.resources.completions.Completions,
- ):
- return AnthropicCallable(*args, client_callable=llm_api, **kwargs)
- except ImportError:
- pass
-
try:
from transformers import ( # noqa: F401 # type: ignore
FlaxPreTrainedModel,
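
With `CohereCallable` and `AnthropicCallable` removed, provider clients are no longer wrapped directly; the deprecation notices above point callers at LiteLLM routing via a `model` argument. A hedged sketch of that replacement path; the model id is taken from the deprecation text and a provider API key is assumed to be configured:

```python
from guardrails import Guard

guard = Guard()

# Instead of passing client.generate or an Anthropic Completions object,
# pass a model id and let LiteLLM resolve the provider.
outcome = guard(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "Reply with the single word 'ok'."}],
)
print(outcome.raw_llm_output)
```
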
diff --git a/guardrails/prompt/__init__.py b/guardrails/prompt/__init__.py
index 15df888b1..1a4e2a655 100644
--- a/guardrails/prompt/__init__.py
+++ b/guardrails/prompt/__init__.py
@@ -1,7 +1,9 @@
from .instructions import Instructions
from .prompt import Prompt
+from .messages import Messages
__all__ = [
"Prompt",
"Instructions",
+ "Messages"
]
diff --git a/guardrails/prompt/messages.py b/guardrails/prompt/messages.py
index a7adec457..81e7328d0 100644
--- a/guardrails/prompt/messages.py
+++ b/guardrails/prompt/messages.py
@@ -8,6 +8,7 @@
from guardrails.classes.templating.namespace_template import NamespaceTemplate
from guardrails.utils.constants import constants
from guardrails.utils.templating_utils import get_template_variables
+from guardrails.prompt.prompt import Prompt
class Messages:
@@ -36,8 +37,12 @@ def __init__(
message["content"] = Template(message["content"]).safe_substitute(
output_schema=output_schema, xml_output_schema=xml_output_schema
)
- else:
- self.source = source
+
+ self.source = self._source
+
+ @property
+ def variable_names(self):
+ return get_template_variables(messages_string(self))
def format(
self,
@@ -59,6 +64,9 @@ def format(
)
return Messages(formatted_messages)
+ def __iter__(self):
+ return iter(self._source)
+
def substitute_constants(self, text):
"""Substitute constants in the prompt."""
# Substitute constants by reading the constants file.
@@ -73,3 +81,15 @@ def substitute_constants(self, text):
text = template.safe_substitute(**mapping)
return text
+
+
+def messages_string(messages: Messages) -> str:
+ messages_copy = ""
+ for msg in messages:
+ content = (
+ msg["content"].source
+            if isinstance(msg.get("content"), Prompt)
+ else msg["content"]
+ )
+ messages_copy += content
+ return messages_copy
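
A small usage sketch of the `Messages` additions above, assuming the class keeps the list it is constructed with (as the new `__iter__` implies); the message content is illustrative:

```python
from guardrails.prompt.messages import Messages, messages_string

messages = Messages(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Greet ${user_name} politely."},
    ]
)

# __iter__ walks the underlying role/content dicts.
for msg in messages:
    print(msg["role"], "->", msg["content"])

# messages_string flattens content so template variables can be inspected,
# which is what the new variable_names property builds on.
print(messages_string(messages))
print(messages.variable_names)
```
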
diff --git a/guardrails/run/async_runner.py b/guardrails/run/async_runner.py
index f10f48946..909ff0974 100644
--- a/guardrails/run/async_runner.py
+++ b/guardrails/run/async_runner.py
@@ -7,19 +7,16 @@
from guardrails.classes.execution.guard_execution_options import GuardExecutionOptions
from guardrails.classes.history import Call, Inputs, Iteration, Outputs
from guardrails.classes.output_type import OutputTypes
-from guardrails.constants import fail_status
from guardrails.errors import ValidationError
from guardrails.llm_providers import AsyncPromptCallableBase, PromptCallableBase
from guardrails.logger import set_scope
-from guardrails.prompt import Instructions, Prompt
from guardrails.run.runner import Runner
-from guardrails.run.utils import msg_history_source, msg_history_string
+from guardrails.run.utils import msg_history_string
from guardrails.schema.validator import schema_validation
from guardrails.types.pydantic import ModelOrListOfModels
from guardrails.types.validator import ValidatorMap
from guardrails.utils.exception_utils import UserFacingException
from guardrails.classes.llm.llm_response import LLMResponse
-from guardrails.utils.prompt_utils import preprocess_prompt, prompt_uses_xml
from guardrails.actions.reask import NonParseableReAsk, ReAsk
from guardrails.utils.telemetry_utils import async_trace
@@ -32,9 +29,7 @@ def __init__(
num_reasks: int,
validation_map: ValidatorMap,
*,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict]] = None,
api: Optional[AsyncPromptCallableBase] = None,
metadata: Optional[Dict[str, Any]] = None,
output: Optional[str] = None,
@@ -48,9 +43,7 @@ def __init__(
output_schema=output_schema,
num_reasks=num_reasks,
validation_map=validation_map,
- prompt=prompt,
- instructions=instructions,
- msg_history=msg_history,
+ messages=messages,
api=api,
metadata=metadata,
output=output,
@@ -78,20 +71,11 @@ async def async_run(
"""
prompt_params = prompt_params or {}
try:
- # Figure out if we need to include instructions in the prompt.
- include_instructions = not (
- self.instructions is None and self.msg_history is None
- )
-
(
- instructions,
- prompt,
- msg_history,
+ messages,
output_schema,
) = (
- self.instructions,
- self.prompt,
- self.msg_history,
+ self.messages,
self.output_schema,
)
index = 0
@@ -100,9 +84,7 @@ async def async_run(
iteration = await self.async_step(
index=index,
api=self.api,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
output_schema=output_schema,
output=self.output if index == 0 else None,
@@ -115,17 +97,14 @@ async def async_run(
# Get new prompt and output schema.
(
- prompt,
- instructions,
+ messages,
output_schema,
- msg_history,
) = self.prepare_to_loop(
iteration.reasks,
output_schema,
parsed_output=iteration.outputs.parsed_output,
validated_output=call_log.validation_response,
prompt_params=prompt_params,
- include_instructions=include_instructions,
)
# Log how many times we reasked
@@ -157,9 +136,7 @@ async def async_step(
call_log: Call,
*,
api: Optional[AsyncPromptCallableBase],
- instructions: Optional[Instructions],
- prompt: Optional[Prompt],
- msg_history: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict]] = None,
prompt_params: Optional[Dict] = None,
output: Optional[str] = None,
) -> Iteration:
@@ -168,9 +145,7 @@ async def async_step(
inputs = Inputs(
llm_api=api,
llm_output=output,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
num_reasks=self.num_reasks,
metadata=self.metadata,
@@ -185,29 +160,21 @@ async def async_step(
try:
# Prepare: run pre-processing, and input validation.
- if output is not None:
- instructions = None
- prompt = None
- msg_history = None
+ if output:
+ messages = None
else:
- instructions, prompt, msg_history = await self.async_prepare(
+ messages = await self.async_prepare(
call_log,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
api=api,
attempt_number=index,
)
- iteration.inputs.instructions = instructions
- iteration.inputs.prompt = prompt
- iteration.inputs.msg_history = msg_history
+ iteration.inputs.messages = messages
# Call: run the API.
- llm_response = await self.async_call(
- instructions, prompt, msg_history, api, output
- )
+ llm_response = await self.async_call(messages, api, output)
iteration.outputs.llm_response_info = llm_response
output = llm_response.output
@@ -249,9 +216,7 @@ async def async_step(
@async_trace(name="call")
async def async_call(
self,
- instructions: Optional[Instructions],
- prompt: Optional[Prompt],
- msg_history: Optional[List[Dict]],
+ messages: Optional[List[Dict]],
api: Optional[AsyncPromptCallableBase],
output: Optional[str] = None,
) -> LLMResponse:
@@ -273,12 +238,8 @@ async def async_call(
)
elif api_fn is None:
raise ValueError("API or output must be provided.")
- elif msg_history:
- llm_response = await api_fn(msg_history=msg_history_source(msg_history))
- elif prompt and instructions:
- llm_response = await api_fn(prompt.source, instructions=instructions.source)
- elif prompt:
- llm_response = await api_fn(prompt.source)
+ elif messages:
+ llm_response = await api_fn(messages=messages.source)
else:
llm_response = await api_fn()
return llm_response
@@ -327,45 +288,30 @@ async def async_prepare(
call_log: Call,
attempt_number: int,
*,
- instructions: Optional[Instructions],
- prompt: Optional[Prompt],
- msg_history: Optional[List[Dict]],
+ messages: Optional[List[Dict]],
prompt_params: Optional[Dict] = None,
api: Optional[Union[PromptCallableBase, AsyncPromptCallableBase]],
- ) -> Tuple[Optional[Instructions], Optional[Prompt], Optional[List[Dict]]]:
+    ) -> Optional[List[Dict]]:
"""Prepare by running pre-processing and input validation.
Returns:
- The instructions, prompt, and message history.
+            The prepared messages.
"""
prompt_params = prompt_params or {}
- has_prompt_validation = "prompt" in self.validation_map
- has_instructions_validation = "instructions" in self.validation_map
- has_msg_history_validation = "msg_history" in self.validation_map
- if msg_history:
- if has_prompt_validation or has_instructions_validation:
- raise UserFacingException(
- ValueError(
- "Prompt and instructions validation are "
- "not supported when using message history."
- )
- )
-
- prompt, instructions = None, None
-
+ if messages:
# Runner.prepare_msg_history
- formatted_msg_history = []
+ formatted_messages = []
# Format any variables in the message history with the prompt params.
- for msg in msg_history:
+ for msg in messages:
msg_copy = copy.deepcopy(msg)
msg_copy["content"] = msg_copy["content"].format(**prompt_params)
- formatted_msg_history.append(msg_copy)
+ formatted_messages.append(msg_copy)
- if "msg_history" in self.validation_map:
- # Runner.validate_msg_history
- msg_str = msg_history_string(formatted_msg_history)
+ if "messages" in self.validation_map:
+            # Runner.validate_messages
+ msg_str = msg_history_string(formatted_messages)
inputs = Inputs(
llm_output=msg_str,
)
@@ -379,112 +325,21 @@ async def async_prepare(
validator_map=self.validation_map,
iteration=iteration,
disable_tracer=self._disable_tracer,
- path="msg_history",
+ path="message",
)
- validated_msg_history = validator_service.post_process_validation(
+ validated_messages = validator_service.post_process_validation(
value, attempt_number, iteration, OutputTypes.STRING
)
- validated_msg_history = cast(str, validated_msg_history)
+ validated_messages = cast(str, validated_messages)
- iteration.outputs.validation_response = validated_msg_history
- if isinstance(validated_msg_history, ReAsk):
+ iteration.outputs.validation_response = validated_messages
+ if isinstance(validated_messages, ReAsk):
raise ValidationError(
- f"Message history validation failed: "
- f"{validated_msg_history}"
+ f"Messages validation failed: " f"{validated_messages}"
)
- if validated_msg_history != msg_str:
+ if validated_messages != msg_str:
raise ValidationError("Message history validation failed")
- elif prompt is not None:
- if has_msg_history_validation:
- raise UserFacingException(
- ValueError(
- "Message history validation is "
- "not supported when using prompt/instructions."
- )
- )
- msg_history = None
-
- use_xml = prompt_uses_xml(prompt._source)
- # Runner.prepare_prompt
- prompt = prompt.format(**prompt_params)
-
- # TODO(shreya): should there be any difference
- # to parsing params for prompt?
- if instructions is not None and isinstance(instructions, Instructions):
- instructions = instructions.format(**prompt_params)
-
- instructions, prompt = preprocess_prompt(
- prompt_callable=api, # type: ignore
- instructions=instructions,
- prompt=prompt,
- output_type=self.output_type,
- use_xml=use_xml,
- )
-
- # validate prompt
- if "prompt" in self.validation_map and prompt is not None:
- # Runner.validate_prompt
- inputs = Inputs(
- llm_output=prompt.source,
- )
- iteration = Iteration(
- call_id=call_log.id, index=attempt_number, inputs=inputs
- )
- call_log.iterations.insert(0, iteration)
- value, _metadata = await validator_service.async_validate(
- value=prompt.source,
- metadata=self.metadata,
- validator_map=self.validation_map,
- iteration=iteration,
- disable_tracer=self._disable_tracer,
- path="prompt",
- )
- validated_prompt = validator_service.post_process_validation(
- value, attempt_number, iteration, OutputTypes.STRING
- )
-
- iteration.outputs.validation_response = validated_prompt
- if isinstance(validated_prompt, ReAsk):
- raise ValidationError(
- f"Prompt validation failed: {validated_prompt}"
- )
- elif not validated_prompt or iteration.status == fail_status:
- raise ValidationError("Prompt validation failed")
- prompt = Prompt(cast(str, validated_prompt))
-
- # validate instructions
- if "instructions" in self.validation_map and instructions is not None:
- # Runner.validate_instructions
- inputs = Inputs(
- llm_output=instructions.source,
- )
- iteration = Iteration(
- call_id=call_log.id, index=attempt_number, inputs=inputs
- )
- call_log.iterations.insert(0, iteration)
- value, _metadata = await validator_service.async_validate(
- value=instructions.source,
- metadata=self.metadata,
- validator_map=self.validation_map,
- iteration=iteration,
- disable_tracer=self._disable_tracer,
- path="instructions",
- )
- validated_instructions = validator_service.post_process_validation(
- value, attempt_number, iteration, OutputTypes.STRING
- )
-
- iteration.outputs.validation_response = validated_instructions
- if isinstance(validated_instructions, ReAsk):
- raise ValidationError(
- f"Instructions validation failed: {validated_instructions}"
- )
- elif not validated_instructions or iteration.status == fail_status:
- raise ValidationError("Instructions validation failed")
- instructions = Instructions(cast(str, validated_instructions))
else:
- raise UserFacingException(
- ValueError("'prompt' or 'msg_history' must be provided.")
- )
+ raise UserFacingException(ValueError("'messages' must be provided."))
- return instructions, prompt, msg_history
+ return messages
diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py
index 5f6206766..d20e6967a 100644
--- a/guardrails/run/async_stream_runner.py
+++ b/guardrails/run/async_stream_runner.py
@@ -23,7 +23,6 @@
PromptCallableBase,
)
from guardrails.logger import set_scope
-from guardrails.prompt import Instructions, Prompt
from guardrails.run import StreamRunner
from guardrails.run.async_runner import AsyncRunner
@@ -35,14 +34,10 @@ async def async_run(
prompt_params = prompt_params or {}
(
- instructions,
- prompt,
- msg_history,
+ messages,
output_schema,
) = (
- self.instructions,
- self.prompt,
- self.msg_history,
+ self.messages,
self.output_schema,
)
@@ -51,9 +46,7 @@ async def async_run(
output_schema,
call_log,
api=self.api,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
output=self.output,
)
@@ -70,9 +63,7 @@ async def async_step(
call_log: Call,
*,
api: Optional[AsyncPromptCallableBase],
- instructions: Optional[Instructions],
- prompt: Optional[Prompt],
- msg_history: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict]] = None,
prompt_params: Optional[Dict] = None,
output: Optional[str] = None,
) -> AsyncIterable[ValidationOutcome]:
@@ -80,9 +71,7 @@ async def async_step(
inputs = Inputs(
llm_api=api,
llm_output=output,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
num_reasks=self.num_reasks,
metadata=self.metadata,
@@ -95,28 +84,20 @@ async def async_step(
)
set_scope(str(id(iteration)))
call_log.iterations.push(iteration)
- if output is not None:
- instructions = None
- prompt = None
- msg_history = None
+ if output:
+ messages = None
else:
- instructions, prompt, msg_history = await self.async_prepare(
+ messages = await self.async_prepare(
call_log,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
api=api,
attempt_number=index,
)
- iteration.inputs.prompt = prompt
- iteration.inputs.instructions = instructions
- iteration.inputs.msg_history = msg_history
+ iteration.inputs.messages = messages
- llm_response = await self.async_call(
- instructions, prompt, msg_history, api, output
- )
+ llm_response = await self.async_call(messages, api, output)
iteration.outputs.llm_response_info = llm_response
stream_output = llm_response.async_stream_output
if not stream_output:
diff --git a/guardrails/run/runner.py b/guardrails/run/runner.py
index caab80ea4..4dc6c6f4e 100644
--- a/guardrails/run/runner.py
+++ b/guardrails/run/runner.py
@@ -1,6 +1,6 @@
import copy
from functools import partial
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from guardrails import validator_service
@@ -8,15 +8,15 @@
from guardrails.classes.execution.guard_execution_options import GuardExecutionOptions
from guardrails.classes.history import Call, Inputs, Iteration, Outputs
from guardrails.classes.output_type import OutputTypes
-from guardrails.constants import fail_status
from guardrails.errors import ValidationError
from guardrails.llm_providers import (
AsyncPromptCallableBase,
PromptCallableBase,
)
from guardrails.logger import set_scope
-from guardrails.prompt import Instructions, Prompt
-from guardrails.run.utils import msg_history_source, msg_history_string
+from guardrails.prompt import Prompt
+from guardrails.prompt.messages import Messages
+from guardrails.run.utils import messages_string
from guardrails.schema.rail_schema import json_schema_to_rail_output
from guardrails.schema.validator import schema_validation
from guardrails.types import ModelOrListOfModels, ValidatorMap, MessageHistory
@@ -29,9 +29,7 @@
prune_extra_keys,
)
from guardrails.utils.prompt_utils import (
- preprocess_prompt,
prompt_content_for_schema,
- prompt_uses_xml,
)
from guardrails.actions.reask import NonParseableReAsk, ReAsk, introspect
from guardrails.utils.telemetry_utils import trace
@@ -61,9 +59,7 @@ class Runner:
metadata: Dict[str, Any]
# LLM Inputs
- prompt: Optional[Prompt] = None
- instructions: Optional[Instructions] = None
- msg_history: Optional[List[Dict[str, Union[Prompt, str]]]] = None
+ messages: Optional[List[Dict[str, Union[Prompt, str]]]] = None
base_model: Optional[ModelOrListOfModels]
exec_options: Optional[GuardExecutionOptions]
@@ -77,7 +73,7 @@ class Runner:
disable_tracer: Optional[bool] = True
# QUESTION: Are any of these init args actually necessary for initialization?
- # ANSWER: _Maybe_ prompt, instructions, and msg_history for Prompt initialization
+ # ANSWER: _Maybe_ messages for Prompt initialization
# but even that can happen at execution time.
# TODO: In versions >=0.6.x, remove this class and just execute a Guard functionally
def __init__(
@@ -87,9 +83,7 @@ def __init__(
num_reasks: int,
validation_map: ValidatorMap,
*,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict]] = None,
api: Optional[PromptCallableBase] = None,
metadata: Optional[Dict[str, Any]] = None,
output: Optional[str] = None,
@@ -113,34 +107,19 @@ def __init__(
xml_output_schema = json_schema_to_rail_output(
json_schema=output_schema, validator_map=validation_map
)
- if prompt:
- self.exec_options.prompt = prompt
- self.prompt = Prompt(
- prompt,
- output_schema=stringified_output_schema,
- xml_output_schema=xml_output_schema,
- )
-
- if instructions:
- self.exec_options.instructions = instructions
- self.instructions = Instructions(
- instructions,
- output_schema=stringified_output_schema,
- xml_output_schema=xml_output_schema,
- )
- if msg_history:
- self.exec_options.msg_history = msg_history
- msg_history_copy = []
- for msg in msg_history:
+ if messages:
+ self.exec_options.messages = messages
+ messages_copy = []
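+            # Wrap each message's content in a Prompt so the output schema
+            # can be substituted into the message text.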
+ for msg in messages:
msg_copy = copy.deepcopy(msg)
msg_copy["content"] = Prompt(
msg_copy["content"],
output_schema=stringified_output_schema,
xml_output_schema=xml_output_schema,
)
- msg_history_copy.append(msg_copy)
- self.msg_history = msg_history_copy
+ messages_copy.append(msg_copy)
+ self.messages = Messages(source=messages_copy)
self.base_model = base_model
@@ -171,35 +150,24 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None) -> Call
"""
prompt_params = prompt_params or {}
try:
- # Figure out if we need to include instructions in the prompt.
- include_instructions = not (
- self.instructions is None and self.msg_history is None
- )
-
# NOTE: At first glance this seems gratuitous,
# but these local variables are reassigned after
# calling self.prepare_to_loop
(
- instructions,
- prompt,
- msg_history,
+ messages,
output_schema,
) = (
- self.instructions,
- self.prompt,
- self.msg_history,
+ self.messages,
self.output_schema,
)
-
+ print("===runner self messages", self.messages)
index = 0
for index in range(self.num_reasks + 1):
# Run a single step.
iteration = self.step(
index=index,
api=self.api,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
output_schema=output_schema,
output=self.output if index == 0 else None,
@@ -210,16 +178,13 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None) -> Call
if not self.do_loop(index, iteration.reasks):
break
- # Get new prompt and output schema.
- (prompt, instructions, output_schema, msg_history) = (
- self.prepare_to_loop(
- iteration.reasks,
- output_schema,
- parsed_output=iteration.outputs.parsed_output,
- validated_output=call_log.validation_response,
- prompt_params=prompt_params,
- include_instructions=include_instructions,
- )
+ # Get new messages and output schema.
+            (output_schema, messages) = self.prepare_to_loop(
+ iteration.reasks,
+ output_schema,
+ parsed_output=iteration.outputs.parsed_output,
+ validated_output=call_log.validation_response,
+ prompt_params=prompt_params,
)
# Log how many times we reasked
@@ -231,7 +196,6 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None) -> Call
is_parent=False, # This span has no children
has_parent=True, # This span has a parent
)
-
except UserFacingException as e:
# Because Pydantic v1 doesn't respect property setters
call_log.exception = e.original_exception
@@ -250,20 +214,17 @@ def step(
call_log: Call,
*,
api: Optional[PromptCallableBase],
- instructions: Optional[Instructions],
- prompt: Optional[Prompt],
- msg_history: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict]] = None,
prompt_params: Optional[Dict] = None,
output: Optional[str] = None,
) -> Iteration:
"""Run a full step."""
+ print("==== step input messages", messages)
prompt_params = prompt_params or {}
inputs = Inputs(
llm_api=api,
llm_output=output,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
num_reasks=self.num_reasks,
metadata=self.metadata,
@@ -278,27 +239,21 @@ def step(
try:
# Prepare: run pre-processing, and input validation.
- if output is not None:
- instructions = None
- prompt = None
- msg_history = None
+ if output:
+ messages = None
else:
- instructions, prompt, msg_history = self.prepare(
+ messages = self.prepare(
call_log,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
api=api,
attempt_number=index,
)
- iteration.inputs.instructions = instructions
- iteration.inputs.prompt = prompt
- iteration.inputs.msg_history = msg_history
+ iteration.inputs.messages = messages
# Call: run the API.
- llm_response = self.call(instructions, prompt, msg_history, api, output)
+ llm_response = self.call(messages, api, output)
iteration.outputs.llm_response_info = llm_response
raw_output = llm_response.output
@@ -335,10 +290,10 @@ def step(
raise e
return iteration
- def validate_msg_history(
- self, call_log: Call, msg_history: MessageHistory, attempt_number: int
+ def validate_messages(
+ self, call_log: Call, messages: MessageHistory, attempt_number: int
) -> None:
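+        """Run input validators registered on "messages" against the message text."""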
- msg_str = msg_history_string(msg_history)
+ msg_str = messages_string(messages)
inputs = Inputs(
llm_output=msg_str,
)
@@ -350,189 +305,69 @@ def validate_msg_history(
validator_map=self.validation_map,
iteration=iteration,
disable_tracer=self._disable_tracer,
- path="msg_history",
+ path="messages",
)
- validated_msg_history = validator_service.post_process_validation(
+ validated_messages = validator_service.post_process_validation(
value, attempt_number, iteration, OutputTypes.STRING
)
- iteration.outputs.validation_response = validated_msg_history
- if isinstance(validated_msg_history, ReAsk):
+ iteration.outputs.validation_response = validated_messages
+ if isinstance(validated_messages, ReAsk):
raise ValidationError(
- f"Message history validation failed: " f"{validated_msg_history}"
+ f"Message validation failed: " f"{validated_messages}"
)
- if validated_msg_history != msg_str:
- raise ValidationError("Message history validation failed")
+ if validated_messages != msg_str:
+ raise ValidationError("Message validation failed")
- def prepare_msg_history(
+ def prepare_messages(
self,
call_log: Call,
- msg_history: MessageHistory,
+ messages: MessageHistory,
prompt_params: Dict,
attempt_number: int,
) -> MessageHistory:
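+        """Substitute prompt params into each message and run input validation."""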
- formatted_msg_history: MessageHistory = []
+ formatted_messages: MessageHistory = []
# Format any variables in the message history with the prompt params.
- for msg in msg_history:
+ for msg in messages:
msg_copy = copy.deepcopy(msg)
msg_copy["content"] = msg_copy["content"].format(**prompt_params)
- formatted_msg_history.append(msg_copy)
+ formatted_messages.append(msg_copy)
-        # validate msg_history
+        # validate messages
- if "msg_history" in self.validation_map:
- self.validate_msg_history(call_log, formatted_msg_history, attempt_number)
-
- return formatted_msg_history
-
- def validate_prompt(self, call_log: Call, prompt: Prompt, attempt_number: int):
- inputs = Inputs(
- llm_output=prompt.source,
- )
- iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs)
- call_log.iterations.insert(0, iteration)
- value, _metadata = validator_service.validate(
- value=prompt.source,
- metadata=self.metadata,
- validator_map=self.validation_map,
- iteration=iteration,
- disable_tracer=self._disable_tracer,
- path="prompt",
- )
-
- validated_prompt = validator_service.post_process_validation(
- value, attempt_number, iteration, OutputTypes.STRING
- )
-
- iteration.outputs.validation_response = validated_prompt
-
- if isinstance(validated_prompt, ReAsk):
- raise ValidationError(f"Prompt validation failed: {validated_prompt}")
- elif not validated_prompt or iteration.status == fail_status:
- raise ValidationError("Prompt validation failed")
- return Prompt(cast(str, validated_prompt))
-
- def validate_instructions(
- self, call_log: Call, instructions: Instructions, attempt_number: int
- ):
- inputs = Inputs(
- llm_output=instructions.source,
- )
- iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs)
- call_log.iterations.insert(0, iteration)
- value, _metadata = validator_service.validate(
- value=instructions.source,
- metadata=self.metadata,
- validator_map=self.validation_map,
- iteration=iteration,
- disable_tracer=self._disable_tracer,
- path="instructions",
- )
- validated_instructions = validator_service.post_process_validation(
- value, attempt_number, iteration, OutputTypes.STRING
- )
+ if "messages" in self.validation_map:
+ self.validate_messages(call_log, formatted_messages, attempt_number)
- iteration.outputs.validation_response = validated_instructions
- if isinstance(validated_instructions, ReAsk):
- raise ValidationError(
- f"Instructions validation failed: {validated_instructions}"
- )
- elif not validated_instructions or iteration.status == fail_status:
- raise ValidationError("Instructions validation failed")
- return Instructions(cast(str, validated_instructions))
-
- def prepare_prompt(
- self,
- call_log: Call,
- instructions: Optional[Instructions],
- prompt: Prompt,
- prompt_params: Dict,
- api: Union[PromptCallableBase, AsyncPromptCallableBase],
- attempt_number: int,
- ):
- use_xml = prompt_uses_xml(self.prompt._source) if self.prompt else False
- prompt = prompt.format(**prompt_params)
-
- # TODO(shreya): should there be any difference
- # to parsing params for prompt?
- if instructions is not None and isinstance(instructions, Instructions):
- instructions = instructions.format(**prompt_params)
-
- instructions, prompt = preprocess_prompt(
- prompt_callable=api,
- instructions=instructions,
- prompt=prompt,
- output_type=self.output_type,
- use_xml=use_xml,
- )
-
- # validate prompt
- if "prompt" in self.validation_map and prompt is not None:
- prompt = self.validate_prompt(call_log, prompt, attempt_number)
-
- # validate instructions
- if "instructions" in self.validation_map and instructions is not None:
- instructions = self.validate_instructions(
- call_log, instructions, attempt_number
- )
-
- return instructions, prompt
+ return formatted_messages
def prepare(
self,
call_log: Call,
attempt_number: int,
*,
- instructions: Optional[Instructions],
- prompt: Optional[Prompt],
- msg_history: Optional[MessageHistory],
+ messages: Optional[MessageHistory],
prompt_params: Optional[Dict] = None,
api: Optional[Union[PromptCallableBase, AsyncPromptCallableBase]],
- ) -> Tuple[Optional[Instructions], Optional[Prompt], Optional[MessageHistory]]:
+    ) -> Optional[MessageHistory]:
"""Prepare by running pre-processing and input validation.
Returns:
- The instructions, prompt, and message history.
+ The messages.
"""
prompt_params = prompt_params or {}
if api is None:
raise UserFacingException(ValueError("API must be provided."))
- has_prompt_validation = "prompt" in self.validation_map
- has_instructions_validation = "instructions" in self.validation_map
- has_msg_history_validation = "msg_history" in self.validation_map
- if msg_history:
- if has_prompt_validation or has_instructions_validation:
- raise UserFacingException(
- ValueError(
- "Prompt and instructions validation are "
- "not supported when using message history."
- )
- )
- prompt, instructions = None, None
- msg_history = self.prepare_msg_history(
- call_log, msg_history, prompt_params, attempt_number
- )
- elif prompt is not None:
- if has_msg_history_validation:
- raise UserFacingException(
- ValueError(
- "Message history validation is "
- "not supported when using prompt/instructions."
- )
- )
- msg_history = None
- instructions, prompt = self.prepare_prompt(
- call_log, instructions, prompt, prompt_params, api, attempt_number
+ if messages:
+ messages = self.prepare_messages(
+ call_log, messages, prompt_params, attempt_number
)
- return instructions, prompt, msg_history
+ return messages
@trace(name="call")
def call(
self,
- instructions: Optional[Instructions],
- prompt: Optional[Prompt],
- msg_history: Optional[MessageHistory],
+ messages: Optional[MessageHistory],
api: Optional[PromptCallableBase],
output: Optional[str] = None,
) -> LLMResponse:
@@ -542,7 +377,7 @@ def call(
2. Convert the response string to a dict,
3. Log the output
"""
-
+ print("====messages", messages)
# If the API supports a base model, pass it in.
api_fn = api
if api is not None:
@@ -554,12 +389,8 @@ def call(
llm_response = LLMResponse(output=output)
elif api_fn is None:
raise ValueError("API or output must be provided.")
- elif msg_history:
- llm_response = api_fn(msg_history=msg_history_source(msg_history))
- elif prompt and instructions:
- llm_response = api_fn(prompt.source, instructions=instructions.source)
- elif prompt:
- llm_response = api_fn(prompt.source)
+ elif messages:
+ llm_response = api_fn(messages=messages.source)
else:
llm_response = api_fn()
@@ -635,16 +466,10 @@ def prepare_to_loop(
parsed_output: Optional[Union[str, List, Dict, ReAsk]] = None,
validated_output: Optional[Union[str, List, Dict, ReAsk]] = None,
prompt_params: Optional[Dict] = None,
- include_instructions: bool = False,
- ) -> Tuple[
- Prompt,
- Optional[Instructions],
- Dict[str, Any],
- Optional[List[Dict]],
- ]:
+    ) -> Tuple[Dict[str, Any], Messages]:
"""Prepare to loop again."""
prompt_params = prompt_params or {}
- output_schema, prompt, instructions = get_reask_setup(
+ output_schema, messages = get_reask_setup(
output_type=self.output_type,
output_schema=output_schema,
validation_map=self.validation_map,
@@ -655,8 +480,5 @@ def prepare_to_loop(
prompt_params=prompt_params,
exec_options=self.exec_options,
)
- if not include_instructions:
- instructions = None
- # todo add messages support
- msg_history = None
- return prompt, instructions, output_schema, msg_history
+
+ return output_schema, messages
diff --git a/guardrails/run/stream_runner.py b/guardrails/run/stream_runner.py
index d53f3dc2e..aa412a458 100644
--- a/guardrails/run/stream_runner.py
+++ b/guardrails/run/stream_runner.py
@@ -9,7 +9,6 @@
OpenAIChatCallable,
PromptCallableBase,
)
-from guardrails.prompt import Instructions, Prompt
from guardrails.run.runner import Runner
from guardrails.utils.parsing_utils import (
coerce_types,
@@ -41,32 +40,20 @@ def __call__(
Returns:
The Call log for this run.
"""
- # This is only used during ReAsks and ReAsks
- # are not yet supported for streaming.
- # Figure out if we need to include instructions in the prompt.
- # include_instructions = not (
- # self.instructions is None and self.msg_history is None
- # )
prompt_params = prompt_params or {}
(
- instructions,
- prompt,
- msg_history,
+ messages,
output_schema,
) = (
- self.instructions,
- self.prompt,
- self.msg_history,
+ self.messages,
self.output_schema,
)
return self.step(
index=0,
api=self.api,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
output_schema=output_schema,
output=self.output,
@@ -78,9 +65,7 @@ def step(
self,
index: int,
api: Optional[PromptCallableBase],
- instructions: Optional[Instructions],
- prompt: Optional[Prompt],
- msg_history: Optional[List[Dict]],
+ messages: Optional[List[Dict]],
prompt_params: Dict,
output_schema: Dict[str, Any],
call_log: Call,
@@ -90,9 +75,7 @@ def step(
inputs = Inputs(
llm_api=api,
llm_output=output,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
num_reasks=self.num_reasks,
metadata=self.metadata,
@@ -106,27 +89,21 @@ def step(
call_log.iterations.push(iteration)
# Prepare: run pre-processing, and input validation.
- if output is not None:
- instructions = None
- prompt = None
- msg_history = None
+ if output:
+ messages = None
else:
- instructions, prompt, msg_history = self.prepare(
+ messages = self.prepare(
call_log,
index,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
api=api,
)
- iteration.inputs.prompt = prompt
- iteration.inputs.instructions = instructions
- iteration.inputs.msg_history = msg_history
+ iteration.inputs.messages = messages
# Call: run the API that returns a generator wrapped in LLMResponse
- llm_response = self.call(instructions, prompt, msg_history, api, output)
+ llm_response = self.call(messages, api, output)
iteration.outputs.llm_response_info = llm_response
diff --git a/guardrails/run/utils.py b/guardrails/run/utils.py
index cdd9de4d2..e6979b4c6 100644
--- a/guardrails/run/utils.py
+++ b/guardrails/run/utils.py
@@ -2,6 +2,7 @@
from typing import Dict, cast
from guardrails.prompt.prompt import Prompt
+from guardrails.prompt.messages import Messages
from guardrails.types.inputs import MessageHistory
@@ -29,3 +30,15 @@ def msg_history_string(msg_history: MessageHistory) -> str:
)
msg_history_copy += content
return msg_history_copy
+
+
+def messages_string(messages: Messages) -> str:
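+    """Concatenate the text content of each message into a single string."""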
+ messages_copy = ""
+ for msg in messages:
+ content = (
+ msg["content"].source
+ if isinstance(msg["content"], Prompt)
+ else msg["content"]
+ )
+ messages_copy += content
+ return messages_copy
diff --git a/guardrails/schema/rail_schema.py b/guardrails/schema/rail_schema.py
index 2aaa04ec4..ab8edbbe7 100644
--- a/guardrails/schema/rail_schema.py
+++ b/guardrails/schema/rail_schema.py
@@ -1,5 +1,4 @@
import jsonref
-import warnings
from dataclasses import dataclass
from string import Template
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
@@ -21,14 +20,7 @@
### RAIL to JSON Schema ###
-STRING_TAGS = [
- "instructions",
- "prompt",
- "reask_instructions",
- "reask_prompt",
- "messages",
- "reask_messages",
-]
+STRING_TAGS = ["messages", "reask_messages"]
def parse_on_fail_handlers(element: _Element) -> Dict[str, OnFailAction]:
@@ -392,49 +384,6 @@ def rail_string_to_schema(rail_string: str) -> ProcessedSchema:
' "string", "object", or "list"'
)
- # Parse instructions for the LLM. These are optional but if given,
- # LLMs can use them to improve their output. Commonly these are
- # prepended to the prompt.
- instructions_tag = rail_xml.find("instructions")
- if instructions_tag is not None:
- parse_element(instructions_tag, processed_schema, "instructions")
- processed_schema.exec_opts.instructions = instructions_tag.text
- warnings.warn(
- "The instructions tag has been deprecated"
- " in favor of messages. Please use messages instead.",
- DeprecationWarning,
- )
-
- # Load
- prompt_tag = rail_xml.find("prompt")
- if prompt_tag is not None:
- parse_element(prompt_tag, processed_schema, "prompt")
- processed_schema.exec_opts.prompt = prompt_tag.text
- warnings.warn(
- "The prompt tag has been deprecated"
- " in favor of messages. Please use messages instead.",
- DeprecationWarning,
- )
-
- # If reasking prompt and instructions are provided, add them to the schema.
- reask_prompt = rail_xml.find("reask_prompt")
- if reask_prompt is not None:
- processed_schema.exec_opts.reask_prompt = reask_prompt.text
- warnings.warn(
- "The reask_prompt tag has been deprecated"
- " in favor of reask_messages. Please use reask_messages instead.",
- DeprecationWarning,
- )
-
- reask_instructions = rail_xml.find("reask_instructions")
- if reask_instructions is not None:
- processed_schema.exec_opts.reask_instructions = reask_instructions.text
- warnings.warn(
- "The reask_instructions tag has been deprecated"
- " in favor of reask_messages. Please use reask_messages instead.",
- DeprecationWarning,
- )
-
messages = rail_xml.find("messages")
if messages is not None:
extracted_messages = []
@@ -455,7 +404,7 @@ def rail_string_to_schema(rail_string: str) -> ProcessedSchema:
role = message.attrib.get("role")
content = message.text
extracted_reask_messages.append({"role": role, "content": content})
- processed_schema.exec_opts.messages = extracted_reask_messages
+ processed_schema.exec_opts.reask_messages = extracted_reask_messages
return processed_schema
diff --git a/tests/integration_tests/test_streaming.py b/tests/integration_tests/test_streaming.py
index a74c58fe5..38da58c2c 100644
--- a/tests/integration_tests/test_streaming.py
+++ b/tests/integration_tests/test_streaming.py
@@ -199,6 +199,8 @@ class MinSentenceLengthNoOp(BaseModel):
${gr.complete_json_suffix}
"""
+MESSAGES=[{"role": "user", "content": PROMPT}]
+
JSON_LLM_CHUNKS = [
'{"statement":',
' "I am DOING',
@@ -212,19 +214,19 @@ class MinSentenceLengthNoOp(BaseModel):
"guard, expected_validated_output",
[
(
- gd.Guard.from_pydantic(output_class=LowerCaseNoop, prompt=PROMPT),
+ gd.Guard.from_pydantic(output_class=LowerCaseNoop, messages=MESSAGES),
expected_noop_output,
),
(
- gd.Guard.from_pydantic(output_class=LowerCaseFix, prompt=PROMPT),
+ gd.Guard.from_pydantic(output_class=LowerCaseFix, messages=MESSAGES),
expected_fix_output,
),
(
- gd.Guard.from_pydantic(output_class=LowerCaseFilter, prompt=PROMPT),
+ gd.Guard.from_pydantic(output_class=LowerCaseFilter, messages=MESSAGES),
expected_filter_refrain_output,
),
(
- gd.Guard.from_pydantic(output_class=LowerCaseRefrain, prompt=PROMPT),
+ gd.Guard.from_pydantic(output_class=LowerCaseRefrain, messages=MESSAGES),
expected_filter_refrain_output,
),
],
@@ -268,19 +270,19 @@ def test_streaming_with_openai_callable(
"guard, expected_validated_output",
[
(
- gd.Guard.from_pydantic(output_class=LowerCaseNoop, prompt=PROMPT),
+ gd.Guard.from_pydantic(output_class=LowerCaseNoop, messages=MESSAGES),
expected_noop_output,
),
(
- gd.Guard.from_pydantic(output_class=LowerCaseFix, prompt=PROMPT),
+ gd.Guard.from_pydantic(output_class=LowerCaseFix, messages=MESSAGES),
expected_fix_output,
),
(
- gd.Guard.from_pydantic(output_class=LowerCaseFilter, prompt=PROMPT),
+ gd.Guard.from_pydantic(output_class=LowerCaseFilter, messages=MESSAGES),
expected_filter_refrain_output,
),
(
- gd.Guard.from_pydantic(output_class=LowerCaseRefrain, prompt=PROMPT),
+ gd.Guard.from_pydantic(output_class=LowerCaseRefrain, messages=MESSAGES),
expected_filter_refrain_output,
),
],
@@ -345,7 +347,7 @@ def test_streaming_with_openai_chat_callable(
validators=[
MinSentenceLengthValidator(26, 30, on_fail=OnFailAction.NOOP)
],
- prompt=STR_PROMPT,
+ messages=MESSAGES,
),
# each value is a tuple
# first is expected text inside span
diff --git a/tests/unit_tests/actions/test_reask.py b/tests/unit_tests/actions/test_reask.py
index ccf6e77d7..5e629551a 100644
--- a/tests/unit_tests/actions/test_reask.py
+++ b/tests/unit_tests/actions/test_reask.py
@@ -542,13 +542,12 @@ def test_get_reask_prompt(
output_schema = processed_schema.json_schema
exec_options = GuardExecutionOptions(
# Use an XML constant to make existing test cases pass
- prompt="${gr.complete_xml_suffix_v3}"
+ messages=[{"role": "system", "content": "${gr.complete_xml_suffix_v3}"}]
)
(
reask_schema,
- reask_prompt,
- reask_instructions,
+ reask_messages,
) = get_reask_setup(
output_type,
output_schema,
@@ -569,9 +568,8 @@ def test_get_reask_prompt(
# json.dumps(json_example, indent=2),
)
- assert reask_prompt.source == expected_prompt
-
- assert reask_instructions.source == expected_instructions
+ assert str(reask_messages.source[0]["content"]) == expected_instructions
+ assert str(reask_messages.source[1]["content"]) == expected_prompt
### FIXME: Implement once Field Level ReAsk is implemented w/ JSON schema ###
diff --git a/tests/unit_tests/classes/history/test_call.py b/tests/unit_tests/classes/history/test_call.py
index f1f2eb034..a9ce9706e 100644
--- a/tests/unit_tests/classes/history/test_call.py
+++ b/tests/unit_tests/classes/history/test_call.py
@@ -8,6 +8,7 @@
from guardrails.llm_providers import ArbitraryCallable
from guardrails.prompt.instructions import Instructions
from guardrails.prompt.prompt import Prompt
+from guardrails.prompt.messages import Messages
from guardrails.classes.llm.llm_response import LLMResponse
from guardrails.classes.validation.validator_logs import ValidatorLogs
from guardrails.actions.reask import ReAsk
@@ -19,13 +20,11 @@ def test_empty_initialization():
assert call.iterations == Stack()
assert call.inputs == CallInputs()
- assert call.prompt is None
assert call.prompt_params is None
- assert call.compiled_prompt is None
assert call.reask_prompts == Stack()
- assert call.instructions is None
- assert call.compiled_instructions is None
- assert call.reask_instructions == Stack()
+ assert call.messages is None
+ assert call.compiled_messages is None
+ assert len(call.reask_messages) == 0
assert call.logs == Stack()
assert call.tokens_consumed is None
assert call.prompt_tokens_consumed is None
@@ -69,8 +68,12 @@ def custom_llm():
# First Iteration Inputs
iter_llm_api = ArbitraryCallable(llm_api=llm_api)
llm_output = "Hello there!"
- instructions = Instructions(source=instructions)
- iter_prompt = Prompt(source=prompt)
+
+ messages = Messages(source=[
+ {"role": "system", "content": instructions},
+ {"role": "user", "content": prompt},
+ ])
+
num_reasks = 0
metadata = {"some_meta_data": "doesn't actually matter"}
full_schema_reask = False
@@ -78,8 +81,7 @@ def custom_llm():
inputs = Inputs(
llm_api=iter_llm_api,
llm_output=llm_output,
- instructions=instructions,
- prompt=iter_prompt,
+ messages=messages,
prompt_params=prompt_params,
num_reasks=num_reasks,
metadata=metadata,
@@ -121,13 +123,14 @@ def custom_llm():
call_id="mock-call", index=0, inputs=inputs, outputs=first_outputs
)
- second_iter_prompt = Prompt(source="That wasn't quite right. Try again.")
+ second_iter_messages = Messages(source=[
+ {"role":"system", "content":"That wasn't quite right. Try again."}
+ ])
second_inputs = Inputs(
llm_api=iter_llm_api,
llm_output=llm_output,
- instructions=instructions,
- prompt=second_iter_prompt,
+ messages=second_iter_messages,
num_reasks=num_reasks,
metadata=metadata,
full_schema_reask=full_schema_reask,
@@ -171,11 +174,10 @@ def custom_llm():
assert call.prompt == prompt
assert call.prompt_params == prompt_params
- assert call.compiled_prompt == "Respond with a friendly greeting."
- assert call.reask_prompts == Stack(second_iter_prompt.source)
- assert call.instructions == instructions.source
- assert call.compiled_instructions == instructions.source
- assert call.reask_instructions == Stack(instructions.source)
+ # TODO FIX this
+ # assert call.messages == messages.source
+ assert call.compiled_messages[1]["content"] == "Respond with a friendly greeting."
+ assert call.reask_messages == Stack(second_iter_messages.source)
# TODO: Test this in the integration tests
assert call.logs == []
diff --git a/tests/unit_tests/classes/history/test_call_inputs.py b/tests/unit_tests/classes/history/test_call_inputs.py
index 74191683f..407f242b4 100644
--- a/tests/unit_tests/classes/history/test_call_inputs.py
+++ b/tests/unit_tests/classes/history/test_call_inputs.py
@@ -6,14 +6,12 @@ def test_empty_initialization():
# Overrides and additional properties
assert call_inputs.llm_api is None
- assert call_inputs.prompt is None
- assert call_inputs.instructions is None
assert call_inputs.args == []
assert call_inputs.kwargs == {}
# Inherited properties
assert call_inputs.llm_output is None
- assert call_inputs.msg_history is None
+ assert call_inputs.messages is None
assert call_inputs.prompt_params is None
assert call_inputs.num_reasks is None
assert call_inputs.metadata is None
diff --git a/tests/unit_tests/classes/history/test_inputs.py b/tests/unit_tests/classes/history/test_inputs.py
index 4ef221c48..830f11e01 100644
--- a/tests/unit_tests/classes/history/test_inputs.py
+++ b/tests/unit_tests/classes/history/test_inputs.py
@@ -1,7 +1,6 @@
from guardrails.classes.history.inputs import Inputs
from guardrails.llm_providers import OpenAICallable
-from guardrails.prompt.instructions import Instructions
-from guardrails.prompt.prompt import Prompt
+from guardrails.prompt.messages import Messages
# Guard against regressions in pydantic BaseModel
@@ -10,9 +9,7 @@ def test_empty_initialization():
assert inputs.llm_api is None
assert inputs.llm_output is None
- assert inputs.instructions is None
- assert inputs.prompt is None
- assert inputs.msg_history is None
+ assert inputs.messages is None
assert inputs.prompt_params is None
assert inputs.num_reasks is None
assert inputs.metadata is None
@@ -22,11 +19,12 @@ def test_empty_initialization():
def test_non_empty_initialization():
llm_api = OpenAICallable(text="Respond with a greeting.")
llm_output = "Hello there!"
- instructions = Instructions(source="You are a greeting bot.")
- prompt = Prompt(source="Respond with a ${greeting_type} greeting.")
- msg_history = [
- {"some_key": "doesn't actually matter because this isn't that strongly typed"}
- ]
+ messages = Messages(
+ source=[
+ {"role": "system", "content": "You are a greeting bot."},
+ {"role": "user", "content": "Respond with a ${greeting_type} greeting."},
+ ]
+ )
prompt_params = {"greeting_type": "friendly"}
num_reasks = 0
metadata = {"some_meta_data": "doesn't actually matter"}
@@ -35,9 +33,7 @@ def test_non_empty_initialization():
inputs = Inputs(
llm_api=llm_api,
llm_output=llm_output,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
num_reasks=num_reasks,
metadata=metadata,
@@ -48,12 +44,8 @@ def test_non_empty_initialization():
assert inputs.llm_api == llm_api
assert inputs.llm_output is not None
assert inputs.llm_output == llm_output
- assert inputs.instructions is not None
- assert inputs.instructions == instructions
- assert inputs.prompt is not None
- assert inputs.prompt == prompt
- assert inputs.msg_history is not None
- assert inputs.msg_history == msg_history
+ assert inputs.messages is not None
+ assert inputs.messages == messages
assert inputs.prompt_params is not None
assert inputs.prompt_params == prompt_params
assert inputs.num_reasks is not None
diff --git a/tests/unit_tests/classes/history/test_iteration.py b/tests/unit_tests/classes/history/test_iteration.py
index 38bc05d73..802e321a1 100644
--- a/tests/unit_tests/classes/history/test_iteration.py
+++ b/tests/unit_tests/classes/history/test_iteration.py
@@ -6,6 +6,7 @@
from guardrails.llm_providers import OpenAICallable
from guardrails.prompt.instructions import Instructions
from guardrails.prompt.prompt import Prompt
+from guardrails.prompt.messages import Messages
from guardrails.classes.llm.llm_response import LLMResponse
from guardrails.classes.validation.validator_logs import ValidatorLogs
from guardrails.actions.reask import FieldReAsk
@@ -40,11 +41,10 @@ def test_non_empty_initialization():
# Inputs
llm_api = OpenAICallable(text="Respond with a greeting.")
llm_output = "Hello there!"
- instructions = Instructions(source="You are a greeting bot.")
- prompt = Prompt(source="Respond with a ${greeting_type} greeting.")
- msg_history = [
- {"some_key": "doesn't actually matter because this isn't that strongly typed"}
- ]
+ messages = Messages(source=[
+ {"role": "system", "content": "You are a greeting bot."},
+ {"role": "user", "content": "Respond with a ${greeting_type} greeting."}
+ ])
prompt_params = {"greeting_type": "friendly"}
num_reasks = 0
metadata = {"some_meta_data": "doesn't actually matter"}
@@ -53,9 +53,7 @@ def test_non_empty_initialization():
inputs = Inputs(
llm_api=llm_api,
llm_output=llm_output,
- instructions=instructions,
- prompt=prompt,
- msg_history=msg_history,
+ messages=messages,
prompt_params=prompt_params,
num_reasks=num_reasks,
metadata=metadata,
diff --git a/tests/unit_tests/test_guard.py b/tests/unit_tests/test_guard.py
index f3e951929..7b2916fab 100644
--- a/tests/unit_tests/test_guard.py
+++ b/tests/unit_tests/test_guard.py
@@ -521,9 +521,9 @@ def test_validate():
# Should still only use the output validators to validate the output
guard: Guard = (
Guard()
- .use(OneLine, on="prompt")
- .use(LowerCase, on="instructions")
- .use(UpperCase, on="msg_history")
+ .use(OneLine, on="messages")
+ .use(LowerCase, on="messages")
+ .use(UpperCase, on="messages")
.use(LowerCase, on="output", on_fail=OnFailAction.FIX)
.use(TwoWords, on="output")
.use(ValidLength, 0, 12, on="output")
@@ -547,9 +547,7 @@ def test_validate():
def test_use_and_use_many():
guard: Guard = (
Guard()
- .use_many(OneLine(), LowerCase(), on="prompt")
- .use(UpperCase, on="instructions")
- .use(LowerCase, on="msg_history")
+ .use_many(OneLine(), LowerCase(), on="messages")
.use_many(
TwoWords(on_fail=OnFailAction.REASK),
ValidLength(0, 12, on_fail=OnFailAction.REFRAIN),
@@ -557,20 +555,12 @@ def test_use_and_use_many():
)
)
- # Check schemas for prompt, instructions and msg_history validators
- prompt_validators = guard._validator_map.get("prompt", [])
+ # Check schemas for messages validators
+ prompt_validators = guard._validator_map.get("messages", [])
assert len(prompt_validators) == 2
assert prompt_validators[0].__class__.__name__ == "OneLine"
assert prompt_validators[1].__class__.__name__ == "LowerCase"
- instructions_validators = guard._validator_map.get("instructions", [])
- assert len(instructions_validators) == 1
- assert instructions_validators[0].__class__.__name__ == "UpperCase"
-
- msg_history_validators = guard._validator_map.get("msg_history", [])
- assert len(msg_history_validators) == 1
- assert msg_history_validators[0].__class__.__name__ == "LowerCase"
-
# Check guard for validators
assert len(guard._validators) == 6
@@ -590,9 +580,8 @@ def test_use_and_use_many():
with pytest.warns(UserWarning):
guard: Guard = (
Guard()
- .use_many(OneLine(), LowerCase(), on="prompt")
- .use(UpperCase, on="instructions")
- .use(LowerCase, on="msg_history")
+ .use_many(OneLine(), LowerCase(), on="messages")
+ .use(UpperCase, on="messages")
.use_many(
TwoWords(on_fail=OnFailAction.REASK),
ValidLength(0, 12, on_fail=OnFailAction.REFRAIN),
diff --git a/tests/unit_tests/test_prompt.py b/tests/unit_tests/test_prompt.py
index 7c933ecea..805563c93 100644
--- a/tests/unit_tests/test_prompt.py
+++ b/tests/unit_tests/test_prompt.py
@@ -6,8 +6,8 @@
from pydantic import BaseModel, Field
import guardrails as gd
-from guardrails.prompt.instructions import Instructions
from guardrails.prompt.prompt import Prompt
+from guardrails.prompt.messages import Messages
from guardrails.utils.constants import constants
from guardrails.utils.prompt_utils import prompt_content_for_schema
@@ -15,28 +15,34 @@
PROMPT = "Extract a string from the text"
-REASK_PROMPT = """
+REASK_MESSAGES = [
+ {
+ "role": "user",
+ "content": """
Please try that again, extract a string from the text
${xml_output_schema}
${previous_response}
-"""
+""",
+ }
+]
SIMPLE_RAIL_SPEC = f"""
-<instructions>
+<messages>
+<message role="system">
 {INSTRUCTIONS}
-</instructions>
-
-<prompt>
+</message>
+<message role="user">
 {PROMPT}
-</prompt>
+</message>
+</messages>
"""
@@ -46,17 +52,18 @@
-<instructions>
+<messages>
+<message role="system">
 ${user_instructions}
-</instructions>
-
-<prompt>
+</message>
+<message role="user">
 ${user_prompt}
-</prompt>
+</message>
+</messages>
"""
@@ -66,18 +73,19 @@
-<instructions>
-You are a helpful bot, who answers only with valid JSON
-</instructions>
-<prompt>
+<messages>
+<message role="system">
+You are a helpful bot, who answers only with valid JSON
+</message>
+<message role="user">
 Extract a string from the text
 ${gr.complete_json_suffix_v2}
-</prompt>
+</message>
+</messages>
"""
@@ -122,6 +130,34 @@
"""
+RAIL_WITH_REASK_MESSAGES = """
+<rail version="0.1">
+<output type="string" />
+<messages>
+<message role="system">
+You are a helpful bot, who answers only with valid JSON
+</message>
+<message role="user">
+${gr.complete_json_suffix_v2}
+</message>
+</messages>
+<reask_messages>
+<message role="user">
+Please try that again, extract a string from the text
+${xml_output_schema}
+${previous_response}
+</message>
+</reask_messages>
+</rail>
+"""
+
+
RAIL_WITH_REASK_INSTRUCTIONS = """