diff --git a/guardrails/actions/reask.py b/guardrails/actions/reask.py index 2ca1f711b..629cab514 100644 --- a/guardrails/actions/reask.py +++ b/guardrails/actions/reask.py @@ -247,7 +247,7 @@ def get_reask_setup_for_string( validation_response: Optional[Union[str, List, Dict, ReAsk]] = None, prompt_params: Optional[Dict[str, Any]] = None, exec_options: Optional[GuardExecutionOptions] = None, -) -> Tuple[Dict[str, Any], Prompt, Instructions]: +) -> Tuple[Dict[str, Any], Messages]: prompt_params = prompt_params or {} exec_options = exec_options or GuardExecutionOptions() @@ -300,7 +300,10 @@ def get_reask_setup_for_string( messages = Messages(exec_options.reask_messages) if messages is None: messages = Messages( - [{"role": "system", "content": "You are a helpful assistant."}] + [ + {"role": "system", "content": instructions}, + {"role": "user", "content": prompt}, + ] ) messages = messages.format( @@ -309,7 +312,17 @@ def get_reask_setup_for_string( **prompt_params, ) - return output_schema, prompt, instructions + return output_schema, messages + + +def get_original_messages(exec_options: GuardExecutionOptions) -> List[Dict[str, Any]]: + exec_options = exec_options or GuardExecutionOptions() + original_messages = exec_options.messages or [] + messages_prompt = next( + (h.get("content") for h in original_messages if isinstance(h, dict)), + "", + ) + return messages_prompt def get_original_prompt(exec_options: Optional[GuardExecutionOptions] = None) -> str: @@ -338,7 +351,7 @@ def get_reask_setup_for_json( use_full_schema: Optional[bool] = False, prompt_params: Optional[Dict[str, Any]] = None, exec_options: Optional[GuardExecutionOptions] = None, -) -> Tuple[Dict[str, Any], Prompt, Instructions]: +) -> Tuple[Dict[str, Any], Messages]: reask_schema = output_schema is_skeleton_reask = not any(isinstance(reask, FieldReAsk) for reask in reasks) is_nonparseable_reask = any( @@ -347,12 +360,10 @@ def get_reask_setup_for_json( error_messages = {} prompt_params = prompt_params or {} exec_options = exec_options or GuardExecutionOptions() - original_prompt = get_original_prompt(exec_options) - use_xml = prompt_uses_xml(original_prompt) + original_messages = get_original_messages(exec_options) + use_xml = prompt_uses_xml(original_messages) reask_prompt_template = None - if exec_options.reask_prompt: - reask_prompt_template = Prompt(exec_options.reask_prompt) if is_nonparseable_reask: if reask_prompt_template is None: @@ -461,31 +472,26 @@ def reask_decoder(obj: ReAsk): **prompt_params, ) - instructions = None - if exec_options.reask_instructions: - instructions = Instructions(exec_options.reask_instructions) - else: - instructions_const = ( - constants["high_level_xml_instructions"] - if use_xml - else constants["high_level_json_instructions"] - ) - instructions = Instructions(instructions_const) + instructions_const = ( + constants["high_level_xml_instructions"] + if use_xml + else constants["high_level_json_instructions"] + ) + instructions = Instructions(instructions_const) instructions = instructions.format(**prompt_params) - # TODO: enable this in 0.6.0 - # messages = None - # if exec_options.reask_messages: - # messages = Messages(exec_options.reask_messages) - # else: - # messages = Messages( - # [ - # {"role": "system", "content": instructions}, - # {"role": "user", "content": prompt}, - # ] - # ) + messages = None + if exec_options.reask_messages: + messages = Messages(exec_options.reask_messages) + else: + messages = Messages( + [ + {"role": "system", "content": instructions}, + {"role": 
"user", "content": prompt}, + ] + ) - return reask_schema, prompt, instructions + return reask_schema, messages def get_reask_setup( @@ -499,7 +505,7 @@ def get_reask_setup( use_full_schema: Optional[bool] = False, prompt_params: Optional[Dict[str, Any]] = None, exec_options: Optional[GuardExecutionOptions] = None, -) -> Tuple[Dict[str, Any], Prompt, Instructions]: +) -> Tuple[Dict[str, Any], Messages]: prompt_params = prompt_params or {} exec_options = exec_options or GuardExecutionOptions() diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index f8f12d085..1c727d8b4 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -92,11 +92,8 @@ def from_pydantic( cls, output_class: ModelOrListOfModels, *, - prompt: Optional[str] = None, - instructions: Optional[str] = None, num_reasks: Optional[int] = None, - reask_prompt: Optional[str] = None, - reask_instructions: Optional[str] = None, + messages: Optional[List[Dict]] = None, reask_messages: Optional[List[Dict]] = None, tracer: Optional[Tracer] = None, name: Optional[str] = None, @@ -104,11 +101,8 @@ def from_pydantic( ): guard = super().from_pydantic( output_class, - prompt=prompt, - instructions=instructions, + messages=messages, num_reasks=num_reasks, - reask_prompt=reask_prompt, - reask_instructions=reask_instructions, reask_messages=reask_messages, tracer=tracer, name=name, @@ -125,10 +119,8 @@ def from_string( validators: Sequence[Validator], *, string_description: Optional[str] = None, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - reask_prompt: Optional[str] = None, - reask_instructions: Optional[str] = None, + messages: Optional[List[Dict]] = None, + reask_messages: Optional[List[Dict]] = None, num_reasks: Optional[int] = None, tracer: Optional[Tracer] = None, name: Optional[str] = None, @@ -137,10 +129,8 @@ def from_string( guard = super().from_string( validators, string_description=string_description, - prompt=prompt, - instructions=instructions, - reask_prompt=reask_prompt, - reask_instructions=reask_instructions, + messages=messages, + reask_messages=reask_messages, num_reasks=num_reasks, tracer=tracer, name=name, @@ -178,9 +168,7 @@ async def _execute( llm_output: Optional[str] = None, prompt_params: Optional[Dict] = None, num_reasks: Optional[int] = None, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, + messages: Optional[List[Dict]] = None, metadata: Optional[Dict], full_schema_reask: Optional[bool] = None, **kwargs, @@ -192,10 +180,8 @@ async def _execute( self._fill_validator_map() self._fill_validators() metadata = metadata or {} - if not llm_output and llm_api and not (prompt or msg_history): - raise RuntimeError( - "'prompt' or 'msg_history' must be provided in order to call an LLM!" 
- ) + if not llm_output and llm_api and not (messages): + raise RuntimeError("'messages' must be provided in order to call an LLM!") # check if validator requirements are fulfilled missing_keys = verify_metadata_requirements(metadata, self._validators) if missing_keys: @@ -210,9 +196,7 @@ async def __exec( llm_output: Optional[str] = None, prompt_params: Optional[Dict] = None, num_reasks: Optional[int] = None, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, + messages: Optional[List[Dict]] = None, metadata: Optional[Dict] = None, full_schema_reask: Optional[bool] = None, **kwargs, @@ -245,14 +229,6 @@ async def __exec( ("guard_id", self.id), ("user_id", self._user_id), ("llm_api", llm_api_str), - ( - "custom_reask_prompt", - self._exec_opts.reask_prompt is not None, - ), - ( - "custom_reask_instructions", - self._exec_opts.reask_instructions is not None, - ), ( "custom_reask_messages", self._exec_opts.reask_messages is not None, @@ -273,14 +249,10 @@ async def __exec( "This should never happen." ) - input_prompt = prompt or self._exec_opts.prompt - input_instructions = instructions or self._exec_opts.instructions + input_messages = messages or self._exec_opts.messages call_inputs = CallInputs( llm_api=llm_api, - prompt=input_prompt, - instructions=input_instructions, - msg_history=msg_history, - prompt_params=prompt_params, + messages=input_messages, num_reasks=self._num_reasks, metadata=metadata, full_schema_reask=full_schema_reask, @@ -298,9 +270,7 @@ async def __exec( prompt_params=prompt_params, metadata=metadata, full_schema_reask=full_schema_reask, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, *args, **kwargs, ) @@ -315,9 +285,7 @@ async def __exec( llm_output=llm_output, prompt_params=prompt_params, num_reasks=self._num_reasks, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, metadata=metadata, full_schema_reask=full_schema_reask, call_log=call_log, @@ -343,9 +311,7 @@ async def __exec( llm_output=llm_output, prompt_params=prompt_params, num_reasks=num_reasks, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, metadata=metadata, full_schema_reask=full_schema_reask, *args, @@ -362,9 +328,7 @@ async def _exec( num_reasks: int = 0, # Should be defined at this point metadata: Dict, # Should be defined at this point full_schema_reask: bool = False, # Should be defined at this point - prompt: Optional[str], - instructions: Optional[str], - msg_history: Optional[List[Dict]], + messages: List[Dict] = None, **kwargs, ) -> Union[ ValidationOutcome[OT], @@ -377,9 +341,7 @@ async def _exec( llm_api: The LLM API to call asynchronously (e.g. openai.Completion.acreate) prompt_params: The parameters to pass to the prompt.format() method. num_reasks: The max times to re-ask the LLM for invalid output. - prompt: The prompt to use for the LLM. - instructions: Instructions for chat models. - msg_history: The message history to pass to the LLM. + messages: List of messages for llm to respond to. metadata: Metadata to pass to the validators. full_schema_reask: When reasking, whether to regenerate the full schema or just the incorrect values. 
@@ -396,9 +358,7 @@ async def _exec( output_schema=self.output_schema.to_dict(), num_reasks=num_reasks, validation_map=self._validator_map, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, api=api, metadata=metadata, output=llm_output, @@ -418,9 +378,7 @@ async def _exec( output_schema=self.output_schema.to_dict(), num_reasks=num_reasks, validation_map=self._validator_map, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, api=api, metadata=metadata, output=llm_output, @@ -441,9 +399,7 @@ async def __call__( *args, prompt_params: Optional[Dict] = None, num_reasks: Optional[int] = 1, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, + messages: Optional[List[Dict]] = None, metadata: Optional[Dict] = None, full_schema_reask: Optional[bool] = None, **kwargs, @@ -460,9 +416,7 @@ async def __call__( (e.g. openai.completions.create or openai.chat.completions.create) prompt_params: The parameters to pass to the prompt.format() method. num_reasks: The max times to re-ask the LLM for invalid output. - prompt: The prompt to use for the LLM. - instructions: Instructions for chat models. - msg_history: The message history to pass to the LLM. + messages: The messages to pass to the LLM. metadata: Metadata to pass to the validators. full_schema_reask: When reasking, whether to regenerate the full schema or just the incorrect values. @@ -473,25 +427,18 @@ async def __call__( The raw text output from the LLM and the validated output. """ - instructions = instructions or self._exec_opts.instructions - prompt = prompt or self._exec_opts.prompt - msg_history = msg_history or kwargs.pop("messages", None) or [] - - if prompt is None: - if msg_history is not None and not len(msg_history): - raise RuntimeError( - "You must provide a prompt if msg_history is empty. " - "Alternatively, you can provide a prompt in the Schema constructor." - ) + messages = messages or self._exec_opts.messages or [] + if messages is not None and not len(messages): + raise RuntimeError( + "You must provide messages to the LLM in order to call it." 
+ ) return await self._execute( *args, llm_api=llm_api, prompt_params=prompt_params, num_reasks=num_reasks, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, metadata=metadata, full_schema_reask=full_schema_reask, **kwargs, @@ -534,14 +481,8 @@ async def parse( if llm_api is None else 1 ) - default_prompt = self._exec_opts.prompt if llm_api is not None else None - prompt = kwargs.pop("prompt", default_prompt) - - default_instructions = self._exec_opts.instructions if llm_api else None - instructions = kwargs.pop("instructions", default_instructions) - - default_msg_history = self._exec_opts.msg_history if llm_api else None - msg_history = kwargs.pop("msg_history", default_msg_history) + default_messages = self._exec_opts.messages if llm_api else None + messages = kwargs.pop("messages", default_messages) return await self._execute( # type: ignore *args, @@ -549,9 +490,7 @@ async def parse( llm_api=llm_api, prompt_params=prompt_params, num_reasks=final_num_reasks, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, metadata=metadata, full_schema_reask=full_schema_reask, **kwargs, diff --git a/guardrails/classes/execution/guard_execution_options.py b/guardrails/classes/execution/guard_execution_options.py index 1230cf7d8..cabdddbcf 100644 --- a/guardrails/classes/execution/guard_execution_options.py +++ b/guardrails/classes/execution/guard_execution_options.py @@ -4,11 +4,6 @@ @dataclass class GuardExecutionOptions: - prompt: Optional[str] = None - instructions: Optional[str] = None - msg_history: Optional[List[Dict]] = None messages: Optional[List[Dict]] = None - reask_prompt: Optional[str] = None - reask_instructions: Optional[str] = None reask_messages: Optional[List[Dict]] = None num_reasks: Optional[int] = None diff --git a/guardrails/classes/history/call.py b/guardrails/classes/history/call.py index cf79a6e2f..3eacc46d4 100644 --- a/guardrails/classes/history/call.py +++ b/guardrails/classes/history/call.py @@ -15,7 +15,6 @@ from guardrails.classes.generic.arbitrary_model import ArbitraryModel from guardrails.classes.validation.validation_result import ValidationResult from guardrails.constants import error_status, fail_status, not_run_status, pass_status -from guardrails.prompt.instructions import Instructions from guardrails.prompt.prompt import Prompt from guardrails.prompt.messages import Messages from guardrails.classes.validation.validator_logs import ValidatorLogs @@ -116,22 +115,22 @@ def reask_prompts(self) -> Stack[Optional[str]]: return Stack() @property - def instructions(self) -> Optional[str]: - """The instructions as provided by the user when initializing or - calling the Guard.""" - return self.inputs.instructions + def messages(self) -> Optional[List[Dict[str, Any]]]: + """The messages as provided by the user when initializing or calling the + Guard.""" + return self.inputs.messages @property - def compiled_instructions(self) -> Optional[str]: - """The initial compiled instructions that were passed to the LLM on the + def compiled_messages(self) -> Optional[List[Dict[str, Any]]]: + """The initial compiled messages that were passed to the LLM on the first call.""" if self.iterations.empty(): return None - initial_inputs = self.iterations.first.inputs # type: ignore - instructions: Instructions = initial_inputs.instructions # type: ignore + initial_inputs = self.iterations.first.inputs + messages = initial_inputs.messages prompt_params = initial_inputs.prompt_params or {} - if 
instructions is not None: - return instructions.format(**prompt_params).source + if messages is not None: + return messages.format(**prompt_params).source @property def reask_messages(self) -> Stack[Messages]: @@ -145,27 +144,7 @@ def reask_messages(self) -> Stack[Messages]: reasks.remove(initial_messages) # type: ignore return Stack( *[ - r.inputs.messages if r.inputs.messages is not None else None - for r in reasks - ] - ) - - return Stack() - - @property - def reask_instructions(self) -> Stack[str]: - """The compiled instructions used during reasks. - - Does not include the initial instructions. - """ - if self.iterations.length > 0: - reasks = self.iterations.copy() - reasks.remove(reasks.first) # type: ignore - return Stack( - *[ - r.inputs.instructions.source - if r.inputs.instructions is not None - else None + r.inputs.messages.source if r.inputs.messages is not None else None for r in reasks ] ) diff --git a/guardrails/classes/history/call_inputs.py b/guardrails/classes/history/call_inputs.py index 4aa2363f8..5f22a817c 100644 --- a/guardrails/classes/history/call_inputs.py +++ b/guardrails/classes/history/call_inputs.py @@ -28,11 +28,8 @@ class CallInputs(Inputs, ICallInputs, ArbitraryModel): "during Guard.__call__ or Guard.parse.", default=None, ) - prompt: Optional[str] = Field( - description="The prompt string as provided by the user.", default=None - ) - instructions: Optional[str] = Field( - description="The instructions string as provided by the user.", default=None + messages: Optional[List[Dict[str, Any]]] = Field( + description="The messages as provided by the user.", default=None ) args: List[Any] = Field( description="Additional arguments for the LLM as provided by the user.", diff --git a/guardrails/classes/history/inputs.py b/guardrails/classes/history/inputs.py index d44d63628..c8b1e05d3 100644 --- a/guardrails/classes/history/inputs.py +++ b/guardrails/classes/history/inputs.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from pydantic import Field @@ -18,10 +18,7 @@ class Inputs(IInputs, ArbitraryModel): for calling the LLM. llm_output (Optional[str]): The string output from an external LLM call provided by the user via Guard.parse. - instructions (Optional[Instructions]): The constructed - Instructions class for chat model calls. - prompt (Optional[Prompt]): The constructed Prompt class. - msg_history (Optional[List[Dict]]): The message history + messages (Optional[List[Dict]]): The message history provided by the user for chat model calls. prompt_params (Optional[Dict]): The parameters provided by the user that will be formatted into the final LLM prompt. 
@@ -42,19 +39,8 @@ class Inputs(IInputs, ArbitraryModel): "provided by the user via Guard.parse.", default=None, ) - instructions: Optional[Instructions] = Field( - description="The constructed Instructions class for chat model calls.", - default=None, - ) - prompt: Optional[Prompt] = Field( - description="The constructed Prompt class.", default=None - ) - msg_history: Optional[List[Dict]] = Field( - description="The message history provided by the user for chat model calls.", - default=None, - ) - messages: Optional[List[Messages]] = Field( - description="The message history provided by the user for chat model calls.", + messages: Optional[Messages] = Field( + description="The messages provided by the user for chat model calls.", default=None, ) prompt_params: Optional[Dict] = Field( diff --git a/guardrails/classes/history/iteration.py b/guardrails/classes/history/iteration.py index 45c60baab..8760d2e49 100644 --- a/guardrails/classes/history/iteration.py +++ b/guardrails/classes/history/iteration.py @@ -188,63 +188,36 @@ def status(self) -> str: @property def rich_group(self) -> Group: - def create_msg_history_table( - msg_history: Optional[List[Dict[str, Prompt]]], + def create_messages_table( + messages: Optional[List[Dict[str, Prompt]]], ) -> Union[str, Table]: - if msg_history is None: - return "No message history." + if messages is None: + return "No messages." table = Table(show_lines=True) table.add_column("Role", justify="right", no_wrap=True) table.add_column("Content") - for msg in msg_history: - table.add_row(str(msg["role"]), msg["content"].source) + for msg in messages: + if isinstance(msg["content"], str): + table.add_row(str(msg["role"]), msg["content"]) + else: + table.add_row(str(msg["role"]), msg["content"].source) return table - table = create_msg_history_table(self.inputs.msg_history) - - if self.inputs.instructions is not None: - return Group( - Panel( - self.inputs.prompt.source if self.inputs.prompt else "No prompt", - title="Prompt", - style="on #F0F8FF", - ), - Panel( - self.inputs.instructions.source, - title="Instructions", - style="on #FFF0F2", - ), - Panel(table, title="Message History", style="on #E7DFEB"), - Panel( - self.raw_output or "", title="Raw LLM Output", style="on #F5F5DC" - ), - Panel( - pretty_repr(self.validation_response), - title="Validated Output", - style="on #F0FFF0", - ), - ) - else: - return Group( - Panel( - self.inputs.prompt.source if self.inputs.prompt else "No prompt", - title="Prompt", - style="on #F0F8FF", - ), - Panel(table, title="Message History", style="on #E7DFEB"), - Panel( - self.raw_output or "", title="Raw LLM Output", style="on #F5F5DC" - ), - Panel( - self.validation_response - if isinstance(self.validation_response, str) - else pretty_repr(self.validation_response), - title="Validated Output", - style="on #F0FFF0", - ), - ) + table = create_messages_table(self.inputs.messages) + + return Group( + Panel(table, title="Messages", style="on #E7DFEB"), + Panel(self.raw_output or "", title="Raw LLM Output", style="on #F5F5DC"), + Panel( + self.validation_response + if isinstance(self.validation_response, str) + else pretty_repr(self.validation_response), + title="Validated Output", + style="on #F0FFF0", + ), + ) def __str__(self) -> str: return pretty_repr(self) diff --git a/guardrails/guard.py b/guardrails/guard.py index fcd40f022..84a25bcd4 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -310,26 +310,17 @@ def _fill_exec_opts( self, *, num_reasks: Optional[int] = None, - prompt: Optional[str] = None, - 
instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, - reask_prompt: Optional[str] = None, - reask_instructions: Optional[str] = None, + messages: Optional[List[Dict]] = None, + reask_messages: Optional[List[Dict]] = None, **kwargs, # noqa ): """Backfill execution options from kwargs.""" if num_reasks is not None: self._exec_opts.num_reasks = num_reasks - if prompt is not None: - self._exec_opts.prompt = prompt - if instructions is not None: - self._exec_opts.instructions = instructions - if msg_history is not None: - self._exec_opts.msg_history = msg_history - if reask_prompt is not None: - self._exec_opts.reask_prompt = reask_prompt - if reask_instructions is not None: - self._exec_opts.reask_instructions = reask_instructions + if messages is not None: + self._exec_opts.messages = messages + if reask_messages is not None: + self._exec_opts.reask_messages = reask_messages @classmethod def _from_rail_schema( @@ -402,6 +393,7 @@ def from_rail( cls._set_tracer(cls, tracer) # type: ignore schema = rail_file_to_schema(rail_file) + print("==== rail schema", schema) return cls._from_rail_schema( schema, rail=rail_file, @@ -451,6 +443,7 @@ def from_rail_string( cls._set_tracer(cls, tracer) # type: ignore schema = rail_string_to_schema(rail_string) + return cls._from_rail_schema( schema, rail=rail_string, @@ -465,11 +458,7 @@ def from_pydantic( cls, output_class: ModelOrListOfModels, *, - prompt: Optional[str] = None, - instructions: Optional[str] = None, num_reasks: Optional[int] = None, - reask_prompt: Optional[str] = None, - reask_instructions: Optional[str] = None, reask_messages: Optional[List[Dict]] = None, messages: Optional[List[Dict]] = None, tracer: Optional[Tracer] = None, @@ -483,10 +472,7 @@ def from_pydantic( Args: output_class: (Union[Type[BaseModel], List[Type[BaseModel]]]): The pydantic model that describes the desired structure of the output. - prompt (str, optional): The prompt used to generate the string. Defaults to None. - instructions (str, optional): Instructions for chat models. Defaults to None. - reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None. - reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None. + messages: (List[Dict], optional): A list of messages to send to the llm. Defaults to None. reask_messages (List[Dict], optional): A list of messages to use during reasks. Defaults to None. num_reasks (int, optional): The max times to re-ask the LLM if validation fails. Deprecated tracer (Tracer, optional): An OpenTelemetry tracer to use for metrics and traces. Defaults to None. @@ -506,31 +492,14 @@ def from_pydantic( DeprecationWarning, ) - if reask_instructions: - warnings.warn( - "reask_instructions is deprecated and will be removed in 0.6.x!" - "Please be prepared to set reask_messages instead.", - DeprecationWarning, - ) - if reask_prompt: - warnings.warn( - "reask_prompt is deprecated and will be removed in 0.6.x!" 
- "Please be prepared to set reask_messages instead.", - DeprecationWarning, - ) - # We have to set the tracer in the ContextStore before the Rail, # and therefore the Validators, are initialized cls._set_tracer(cls, tracer) # type: ignore schema = pydantic_model_to_schema(output_class) exec_opts = GuardExecutionOptions( - prompt=prompt, - instructions=instructions, - reask_prompt=reask_prompt, - reask_instructions=reask_instructions, - reask_messages=reask_messages, messages=messages, + reask_messages=reask_messages, ) guard = cls( name=name, @@ -566,10 +535,6 @@ def from_string( validators: Sequence[Validator], *, string_description: Optional[str] = None, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - reask_prompt: Optional[str] = None, - reask_instructions: Optional[str] = None, reask_messages: Optional[List[Dict]] = None, messages: Optional[List[Dict]] = None, num_reasks: Optional[int] = None, @@ -582,28 +547,13 @@ def from_string( Args: validators: (List[Validator]): The list of validators to apply to the string output. string_description (str, optional): A description for the string to be generated. Defaults to None. - prompt (str, optional): The prompt used to generate the string. Defaults to None. - instructions (str, optional): Instructions for chat models. Defaults to None. - reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None. - reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None. + messages: (List[Dict], optional): A list of messages to send to the llm. Defaults to None. reask_messages (List[Dict], optional): A list of messages to use during reasks. Defaults to None. num_reasks (int, optional): The max times to re-ask the LLM if validation fails. Deprecated tracer (Tracer, optional): An OpenTelemetry tracer to use for metrics and traces. Defaults to None. name (str, optional): A unique name for this Guard. Defaults to `gr-` + the object id. description (str, optional): A description for this Guard. Defaults to None. """ # noqa - if reask_instructions: - warnings.warn( - "reask_instructions is deprecated and will be removed in 0.6.x!" - "Please be prepared to set reask_messages instead.", - DeprecationWarning, - ) - if reask_prompt: - warnings.warn( - "reask_prompt is deprecated and will be removed in 0.6.x!" 
- "Please be prepared to set reask_messages instead.", - DeprecationWarning, - ) if num_reasks: warnings.warn( @@ -623,10 +573,6 @@ def from_string( list(validators), type=SimpleTypes.STRING, description=string_description ) exec_opts = GuardExecutionOptions( - prompt=prompt, - instructions=instructions, - reask_prompt=reask_prompt, - reask_instructions=reask_instructions, reask_messages=reask_messages, messages=messages, ) @@ -653,11 +599,8 @@ def _execute( llm_output: Optional[str] = None, prompt_params: Optional[Dict] = None, num_reasks: Optional[int] = None, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, - reask_prompt: Optional[str] = None, - reask_instructions: Optional[str] = None, + messages: Optional[List[Dict]] = None, + reask_messages: Optional[List[Dict]] = None, metadata: Optional[Dict], full_schema_reask: Optional[bool] = None, **kwargs, @@ -666,17 +609,13 @@ def _execute( self._fill_validators() self._fill_exec_opts( num_reasks=num_reasks, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, - reask_prompt=reask_prompt, - reask_instructions=reask_instructions, + messages=messages, + reask_messages=reask_messages, ) metadata = metadata or {} - if not llm_output and llm_api and not (prompt or msg_history): - raise RuntimeError( - "'prompt' or 'msg_history' must be provided in order to call an LLM!" - ) + print("==== _execute messages", messages) + if not (messages) and llm_api: + raise RuntimeError("'messages' must be provided in order to call an LLM!") # check if validator requirements are fulfilled missing_keys = verify_metadata_requirements(metadata, self._validators) @@ -692,9 +631,7 @@ def __exec( llm_output: Optional[str] = None, prompt_params: Optional[Dict] = None, num_reasks: Optional[int] = None, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, + messages: Optional[List[Dict]] = None, metadata: Optional[Dict] = None, full_schema_reask: Optional[bool] = None, **kwargs, @@ -723,14 +660,6 @@ def __exec( ("guard_id", self.id), ("user_id", self._user_id), ("llm_api", llm_api_str if llm_api_str else "None"), - ( - "custom_reask_prompt", - self._exec_opts.reask_prompt is not None, - ), - ( - "custom_reask_instructions", - self._exec_opts.reask_instructions is not None, - ), ( "custom_reask_messages", self._exec_opts.reask_messages is not None, @@ -751,13 +680,10 @@ def __exec( "This should never happen." 
) - input_prompt = prompt or self._exec_opts.prompt - input_instructions = instructions or self._exec_opts.instructions + input_messages = messages or self._exec_opts.messages call_inputs = CallInputs( llm_api=llm_api, - prompt=input_prompt, - instructions=input_instructions, - msg_history=msg_history, + messages=input_messages, prompt_params=prompt_params, num_reasks=self._num_reasks, metadata=metadata, @@ -776,9 +702,7 @@ def __exec( prompt_params=prompt_params, metadata=metadata, full_schema_reask=full_schema_reask, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, *args, **kwargs, ) @@ -787,14 +711,13 @@ def __exec( set_scope(str(object_id(call_log))) self.history.push(call_log) # Otherwise, call the LLM synchronously + print("====executing messages", messages) return self._exec( llm_api=llm_api, llm_output=llm_output, prompt_params=prompt_params, num_reasks=self._num_reasks, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, metadata=metadata, full_schema_reask=full_schema_reask, call_log=call_log, @@ -817,9 +740,7 @@ def __exec( llm_output=llm_output, prompt_params=prompt_params, num_reasks=num_reasks, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, metadata=metadata, full_schema_reask=full_schema_reask, *args, @@ -836,9 +757,7 @@ def _exec( num_reasks: int = 0, # Should be defined at this point metadata: Dict, # Should be defined at this point full_schema_reask: bool = False, # Should be defined at this point - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, + messages: Optional[List[Dict]] = None, **kwargs, ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: api = None @@ -858,9 +777,7 @@ def _exec( output_schema=self.output_schema.to_dict(), num_reasks=num_reasks, validation_map=self._validator_map, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, api=api, metadata=metadata, output=llm_output, @@ -877,9 +794,7 @@ def _exec( output_schema=self.output_schema.to_dict(), num_reasks=num_reasks, validation_map=self._validator_map, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, api=api, metadata=metadata, output=llm_output, @@ -897,9 +812,6 @@ def __call__( *args, prompt_params: Optional[Dict] = None, num_reasks: Optional[int] = 1, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, metadata: Optional[Dict] = None, full_schema_reask: Optional[bool] = None, **kwargs, @@ -911,9 +823,6 @@ def __call__( (e.g. openai.completions.create or openai.Completion.acreate) prompt_params: The parameters to pass to the prompt.format() method. num_reasks: The max times to re-ask the LLM for invalid output. - prompt: The prompt to use for the LLM. - instructions: Instructions for chat models. - msg_history: The message history to pass to the LLM. metadata: Metadata to pass to the validators. full_schema_reask: When reasking, whether to regenerate the full schema or just the incorrect values. 
@@ -923,24 +832,20 @@ def __call__( Returns: ValidationOutcome """ - instructions = instructions or self._exec_opts.instructions - prompt = prompt or self._exec_opts.prompt - msg_history = msg_history or kwargs.get("messages", None) or [] - if prompt is None: - if msg_history is not None and not len(msg_history): - raise RuntimeError( - "You must provide a prompt if msg_history is empty. " - "Alternatively, you can provide a prompt in the Schema constructor." - ) - + messages = kwargs.pop("messages", None) or self._exec_opts.messages or [] + print("==== call kwargs", kwargs) + print("==== call messages", messages) + if messages is not None and not len(messages): + raise RuntimeError( + "You must provide messages " + "Alternatively, you can provide a messages in the Schema constructor." + ) return self._execute( *args, llm_api=llm_api, prompt_params=prompt_params, num_reasks=num_reasks, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, metadata=metadata, full_schema_reask=full_schema_reask, **kwargs, @@ -981,14 +886,8 @@ def parse( if llm_api is None else 1 ) - default_prompt = self._exec_opts.prompt if llm_api else None - prompt = kwargs.pop("prompt", default_prompt) - - default_instructions = self._exec_opts.instructions if llm_api else None - instructions = kwargs.pop("instructions", default_instructions) - - default_msg_history = self._exec_opts.msg_history if llm_api else None - msg_history = kwargs.pop("msg_history", default_msg_history) + default_messages = self._exec_opts.messages if llm_api else None + messages = kwargs.pop("messages", default_messages) return self._execute( # type: ignore # streams are supported for parse *args, @@ -996,9 +895,7 @@ def parse( llm_api=llm_api, prompt_params=prompt_params, num_reasks=final_num_reasks, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, metadata=metadata, full_schema_reask=full_schema_reask, **kwargs, @@ -1020,14 +917,12 @@ def error_spans_in_output(self) -> List[ErrorSpan]: def __add_validator(self, validator: Validator, on: str = "output"): if on not in [ "output", - "prompt", - "instructions", - "msg_history", + "messages", ] and not on.startswith("$"): warnings.warn( f"Unusual 'on' value: {on}!" "This value is typically one of " - "'output', 'prompt', 'instructions', 'msg_history') " + "'output', 'messages') " "or a JSON path starting with '$.'", UserWarning, ) @@ -1063,9 +958,7 @@ def use( ) -> "Guard": """Use a validator to validate either of the following: - The output of an LLM request - - The prompt - - The instructions - - The message history + - The messages Args: validator: The validator to use. Either the class or an instance. 
@@ -1248,16 +1141,10 @@ def _call_server( if llm_api is not None: payload["llmApi"] = get_llm_api_enum(llm_api, *args, **kwargs) - if not payload.get("prompt"): - payload["prompt"] = self._exec_opts.prompt - if not payload.get("instructions"): - payload["instructions"] = self._exec_opts.instructions - if not payload.get("msg_history"): - payload["msg_history"] = self._exec_opts.msg_history - if not payload.get("reask_prompt"): - payload["reask_prompt"] = self._exec_opts.reask_prompt - if not payload.get("reask_instructions"): - payload["reask_instructions"] = self._exec_opts.reask_instructions + if not payload.get("messages"): + payload["messages"] = self._exec_opts.messages + if not payload.get("reask_messages"): + payload["reask_messages"] = self._exec_opts.reask_messages should_stream = kwargs.get("stream", False) if should_stream: diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py index 2eb3d8f26..c656259ab 100644 --- a/guardrails/llm_providers.py +++ b/guardrails/llm_providers.py @@ -309,170 +309,6 @@ def _invoke_llm( ) -class CohereCallable(PromptCallableBase): - def _invoke_llm( - self, prompt: str, client_callable: Any, model: str, *args, **kwargs - ) -> LLMResponse: - """To use cohere for guardrails, do ``` client = - cohere.Client(api_key=...) - - raw_llm_response, validated_response, *rest = guard( - client.generate, - prompt_params={...}, - model="command-nightly", - ... - ) - ``` - """ # noqa - warnings.warn( - "This callable is deprecated in favor of passing " - "no callable and the model argument which utilizes LiteLLM" - "for example guard(model='command-r', messages=[...], ...)", - DeprecationWarning, - ) - - trace_input_messages = chat_prompt(prompt, kwargs.get("instructions")) - if "instructions" in kwargs: - prompt = kwargs.pop("instructions") + "\n\n" + prompt - - def is_base_cohere_chat(func): - try: - return ( - func.__closure__[1].cell_contents.__func__.__qualname__ - == "BaseCohere.chat" - ) - except (AttributeError, IndexError): - return False - - # TODO: When cohere totally gets rid of `generate`, - # remove this cond and the final return - if is_base_cohere_chat(client_callable): - trace_operation( - input_mime_type="application/json", - input_value={**kwargs, "message": prompt, "args": args, "model": model}, - ) - - trace_llm_call( - input_messages=trace_input_messages, - invocation_parameters={**kwargs, "message": prompt, "model": model}, - ) - cohere_response = client_callable( - message=prompt, model=model, *args, **kwargs - ) - trace_operation( - output_mime_type="application/json", output_value=cohere_response - ) - trace_llm_call( - output_messages=[{"role": "assistant", "content": cohere_response.text}] - ) - return LLMResponse( - output=cohere_response.text, - ) - - trace_operation( - input_mime_type="application/json", - input_value={**kwargs, "prompt": prompt, "args": args, "model": model}, - ) - - trace_llm_call( - input_messages=trace_input_messages, - invocation_parameters={**kwargs, "prompt": prompt, "model": model}, - ) - cohere_response = client_callable(prompt=prompt, model=model, *args, **kwargs) - trace_operation( - output_mime_type="application/json", output_value=cohere_response - ) - trace_llm_call( - output_messages=[{"role": "assistant", "content": cohere_response[0].text}] - ) - return LLMResponse( - output=cohere_response[0].text, - ) - - -class AnthropicCallable(PromptCallableBase): - def _invoke_llm( - self, - prompt: str, - client_callable: Any, - model: str = "claude-instant-1", - max_tokens_to_sample: int = 
100, - *args, - **kwargs, - ) -> LLMResponse: - """Wrapper for Anthropic Completions. - - To use Anthropic for guardrails, do - ``` - client = anthropic.Anthropic(api_key=...) - - raw_llm_response, validated_response = guard( - client, - model="claude-2", - max_tokens_to_sample=200, - prompt_params={...}, - ... - ``` - """ - warnings.warn( - "This callable is deprecated in favor of passing " - "no callable and the model argument which utilizes LiteLLM" - "for example guard(model='claude-3-opus-20240229', messages=[...], ...)", - DeprecationWarning, - ) - try: - import anthropic - except ImportError: - raise PromptCallableException( - "The `anthropic` package is not installed. " - "Install with `pip install anthropic`" - ) - - trace_input_messages = chat_prompt(prompt, kwargs.get("instructions")) - if "instructions" in kwargs: - prompt = kwargs.pop("instructions") + "\n\n" + prompt - - anthropic_prompt = f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}" - - trace_operation( - input_mime_type="application/json", - input_value={ - **kwargs, - "model": model, - "prompt": anthropic_prompt, - "max_tokens_to_sample": max_tokens_to_sample, - "args": args, - }, - ) - - trace_llm_call( - input_messages=trace_input_messages, - invocation_parameters={ - **kwargs, - "model": model, - "prompt": anthropic_prompt, - "max_tokens_to_sample": max_tokens_to_sample, - }, - ) - - anthropic_response = client_callable( - model=model, - prompt=anthropic_prompt, - max_tokens_to_sample=max_tokens_to_sample, - *args, - **kwargs, - ) - trace_operation( - output_mime_type="application/json", output_value=anthropic_response - ) - trace_llm_call( - output_messages=[ - {"role": "assistant", "content": anthropic_response.completion} - ] - ) - return LLMResponse(output=anthropic_response.completion) - - class LiteLLMCallable(PromptCallableBase): def _invoke_llm( self, @@ -815,7 +651,6 @@ def get_llm_ask( ) -> Optional[PromptCallableBase]: if "temperature" not in kwargs: kwargs.update({"temperature": 0}) - try: from litellm import completion @@ -824,11 +659,6 @@ def get_llm_ask( except ImportError: pass - if llm_api == get_static_openai_create_func(): - return OpenAICallable(*args, **kwargs) - if llm_api == get_static_openai_chat_create_func(): - return OpenAIChatCallable(*args, **kwargs) - try: import manifest # noqa: F401 # type: ignore @@ -837,28 +667,6 @@ def get_llm_ask( except ImportError: pass - try: - import cohere # noqa: F401 # type: ignore - - if ( - isinstance(getattr(llm_api, "__self__", None), cohere.Client) - and getattr(llm_api, "__name__", None) == "generate" - ) or getattr(llm_api, "__module__", None) == "cohere.client": - return CohereCallable(*args, client_callable=llm_api, **kwargs) - except ImportError: - pass - - try: - import anthropic.resources # noqa: F401 # type: ignore - - if isinstance( - getattr(llm_api, "__self__", None), - anthropic.resources.completions.Completions, - ): - return AnthropicCallable(*args, client_callable=llm_api, **kwargs) - except ImportError: - pass - try: from transformers import ( # noqa: F401 # type: ignore FlaxPreTrainedModel, diff --git a/guardrails/prompt/__init__.py b/guardrails/prompt/__init__.py index 15df888b1..1a4e2a655 100644 --- a/guardrails/prompt/__init__.py +++ b/guardrails/prompt/__init__.py @@ -1,7 +1,9 @@ from .instructions import Instructions from .prompt import Prompt +from .messages import Messages __all__ = [ "Prompt", "Instructions", + "Messages" ] diff --git a/guardrails/prompt/messages.py b/guardrails/prompt/messages.py index 
a7adec457..81e7328d0 100644 --- a/guardrails/prompt/messages.py +++ b/guardrails/prompt/messages.py @@ -8,6 +8,7 @@ from guardrails.classes.templating.namespace_template import NamespaceTemplate from guardrails.utils.constants import constants from guardrails.utils.templating_utils import get_template_variables +from guardrails.prompt.prompt import Prompt class Messages: @@ -36,8 +37,12 @@ def __init__( message["content"] = Template(message["content"]).safe_substitute( output_schema=output_schema, xml_output_schema=xml_output_schema ) - else: - self.source = source + + self.source = self._source + + @property + def variable_names(self): + return get_template_variables(messages_string(self)) def format( self, @@ -59,6 +64,9 @@ def format( ) return Messages(formatted_messages) + def __iter__(self): + return iter(self._source) + def substitute_constants(self, text): """Substitute constants in the prompt.""" # Substitute constants by reading the constants file. @@ -73,3 +81,15 @@ def substitute_constants(self, text): text = template.safe_substitute(**mapping) return text + + +def messages_string(messages: Messages) -> str: + messages_copy = "" + for msg in messages: + content = ( + msg["content"].source + if getattr(msg, "content", None) and isinstance(msg["content"], Prompt) + else msg["content"] + ) + messages_copy += content + return messages_copy diff --git a/guardrails/run/async_runner.py b/guardrails/run/async_runner.py index f10f48946..909ff0974 100644 --- a/guardrails/run/async_runner.py +++ b/guardrails/run/async_runner.py @@ -7,19 +7,16 @@ from guardrails.classes.execution.guard_execution_options import GuardExecutionOptions from guardrails.classes.history import Call, Inputs, Iteration, Outputs from guardrails.classes.output_type import OutputTypes -from guardrails.constants import fail_status from guardrails.errors import ValidationError from guardrails.llm_providers import AsyncPromptCallableBase, PromptCallableBase from guardrails.logger import set_scope -from guardrails.prompt import Instructions, Prompt from guardrails.run.runner import Runner -from guardrails.run.utils import msg_history_source, msg_history_string +from guardrails.run.utils import msg_history_string from guardrails.schema.validator import schema_validation from guardrails.types.pydantic import ModelOrListOfModels from guardrails.types.validator import ValidatorMap from guardrails.utils.exception_utils import UserFacingException from guardrails.classes.llm.llm_response import LLMResponse -from guardrails.utils.prompt_utils import preprocess_prompt, prompt_uses_xml from guardrails.actions.reask import NonParseableReAsk, ReAsk from guardrails.utils.telemetry_utils import async_trace @@ -32,9 +29,7 @@ def __init__( num_reasks: int, validation_map: ValidatorMap, *, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, + messages: Optional[List[Dict]] = None, api: Optional[AsyncPromptCallableBase] = None, metadata: Optional[Dict[str, Any]] = None, output: Optional[str] = None, @@ -48,9 +43,7 @@ def __init__( output_schema=output_schema, num_reasks=num_reasks, validation_map=validation_map, - prompt=prompt, - instructions=instructions, - msg_history=msg_history, + messages=messages, api=api, metadata=metadata, output=output, @@ -78,20 +71,11 @@ async def async_run( """ prompt_params = prompt_params or {} try: - # Figure out if we need to include instructions in the prompt. 
- include_instructions = not ( - self.instructions is None and self.msg_history is None - ) - ( - instructions, - prompt, - msg_history, + messages, output_schema, ) = ( - self.instructions, - self.prompt, - self.msg_history, + self.messages, self.output_schema, ) index = 0 @@ -100,9 +84,7 @@ async def async_run( iteration = await self.async_step( index=index, api=self.api, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, output_schema=output_schema, output=self.output if index == 0 else None, @@ -115,17 +97,14 @@ async def async_run( # Get new prompt and output schema. ( - prompt, - instructions, + messages, output_schema, - msg_history, ) = self.prepare_to_loop( iteration.reasks, output_schema, parsed_output=iteration.outputs.parsed_output, validated_output=call_log.validation_response, prompt_params=prompt_params, - include_instructions=include_instructions, ) # Log how many times we reasked @@ -157,9 +136,7 @@ async def async_step( call_log: Call, *, api: Optional[AsyncPromptCallableBase], - instructions: Optional[Instructions], - prompt: Optional[Prompt], - msg_history: Optional[List[Dict]] = None, + messages: Optional[List[Dict]] = None, prompt_params: Optional[Dict] = None, output: Optional[str] = None, ) -> Iteration: @@ -168,9 +145,7 @@ async def async_step( inputs = Inputs( llm_api=api, llm_output=output, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, num_reasks=self.num_reasks, metadata=self.metadata, @@ -185,29 +160,21 @@ async def async_step( try: # Prepare: run pre-processing, and input validation. - if output is not None: - instructions = None - prompt = None - msg_history = None + if output: + messages = None else: - instructions, prompt, msg_history = await self.async_prepare( + messages = await self.async_prepare( call_log, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, api=api, attempt_number=index, ) - iteration.inputs.instructions = instructions - iteration.inputs.prompt = prompt - iteration.inputs.msg_history = msg_history + iteration.inputs.messages = messages # Call: run the API. 
- llm_response = await self.async_call( - instructions, prompt, msg_history, api, output - ) + llm_response = await self.async_call(messages, api, output) iteration.outputs.llm_response_info = llm_response output = llm_response.output @@ -249,9 +216,7 @@ async def async_step( @async_trace(name="call") async def async_call( self, - instructions: Optional[Instructions], - prompt: Optional[Prompt], - msg_history: Optional[List[Dict]], + messages: Optional[List[Dict]], api: Optional[AsyncPromptCallableBase], output: Optional[str] = None, ) -> LLMResponse: @@ -273,12 +238,8 @@ async def async_call( ) elif api_fn is None: raise ValueError("API or output must be provided.") - elif msg_history: - llm_response = await api_fn(msg_history=msg_history_source(msg_history)) - elif prompt and instructions: - llm_response = await api_fn(prompt.source, instructions=instructions.source) - elif prompt: - llm_response = await api_fn(prompt.source) + elif messages: + llm_response = await api_fn(messages=messages.source) else: llm_response = await api_fn() return llm_response @@ -327,45 +288,30 @@ async def async_prepare( call_log: Call, attempt_number: int, *, - instructions: Optional[Instructions], - prompt: Optional[Prompt], - msg_history: Optional[List[Dict]], + messages: Optional[List[Dict]], prompt_params: Optional[Dict] = None, api: Optional[Union[PromptCallableBase, AsyncPromptCallableBase]], - ) -> Tuple[Optional[Instructions], Optional[Prompt], Optional[List[Dict]]]: + ) -> Tuple[Optional[List[Dict]]]: """Prepare by running pre-processing and input validation. Returns: - The instructions, prompt, and message history. + The messages """ prompt_params = prompt_params or {} - has_prompt_validation = "prompt" in self.validation_map - has_instructions_validation = "instructions" in self.validation_map - has_msg_history_validation = "msg_history" in self.validation_map - if msg_history: - if has_prompt_validation or has_instructions_validation: - raise UserFacingException( - ValueError( - "Prompt and instructions validation are " - "not supported when using message history." - ) - ) - - prompt, instructions = None, None - + if messages: # Runner.prepare_msg_history - formatted_msg_history = [] + formatted_messages = [] # Format any variables in the message history with the prompt params. 
- for msg in msg_history: + for msg in messages: msg_copy = copy.deepcopy(msg) msg_copy["content"] = msg_copy["content"].format(**prompt_params) - formatted_msg_history.append(msg_copy) + formatted_messages.append(msg_copy) - if "msg_history" in self.validation_map: - # Runner.validate_msg_history - msg_str = msg_history_string(formatted_msg_history) + if "messages" in self.validation_map: + # Runner.validate_message + msg_str = msg_history_string(formatted_messages) inputs = Inputs( llm_output=msg_str, ) @@ -379,112 +325,21 @@ async def async_prepare( validator_map=self.validation_map, iteration=iteration, disable_tracer=self._disable_tracer, - path="msg_history", + path="message", ) - validated_msg_history = validator_service.post_process_validation( + validated_messages = validator_service.post_process_validation( value, attempt_number, iteration, OutputTypes.STRING ) - validated_msg_history = cast(str, validated_msg_history) + validated_messages = cast(str, validated_messages) - iteration.outputs.validation_response = validated_msg_history - if isinstance(validated_msg_history, ReAsk): + iteration.outputs.validation_response = validated_messages + if isinstance(validated_messages, ReAsk): raise ValidationError( - f"Message history validation failed: " - f"{validated_msg_history}" + f"Messages validation failed: " f"{validated_messages}" ) - if validated_msg_history != msg_str: + if validated_messages != msg_str: raise ValidationError("Message history validation failed") - elif prompt is not None: - if has_msg_history_validation: - raise UserFacingException( - ValueError( - "Message history validation is " - "not supported when using prompt/instructions." - ) - ) - msg_history = None - - use_xml = prompt_uses_xml(prompt._source) - # Runner.prepare_prompt - prompt = prompt.format(**prompt_params) - - # TODO(shreya): should there be any difference - # to parsing params for prompt? 
- if instructions is not None and isinstance(instructions, Instructions): - instructions = instructions.format(**prompt_params) - - instructions, prompt = preprocess_prompt( - prompt_callable=api, # type: ignore - instructions=instructions, - prompt=prompt, - output_type=self.output_type, - use_xml=use_xml, - ) - - # validate prompt - if "prompt" in self.validation_map and prompt is not None: - # Runner.validate_prompt - inputs = Inputs( - llm_output=prompt.source, - ) - iteration = Iteration( - call_id=call_log.id, index=attempt_number, inputs=inputs - ) - call_log.iterations.insert(0, iteration) - value, _metadata = await validator_service.async_validate( - value=prompt.source, - metadata=self.metadata, - validator_map=self.validation_map, - iteration=iteration, - disable_tracer=self._disable_tracer, - path="prompt", - ) - validated_prompt = validator_service.post_process_validation( - value, attempt_number, iteration, OutputTypes.STRING - ) - - iteration.outputs.validation_response = validated_prompt - if isinstance(validated_prompt, ReAsk): - raise ValidationError( - f"Prompt validation failed: {validated_prompt}" - ) - elif not validated_prompt or iteration.status == fail_status: - raise ValidationError("Prompt validation failed") - prompt = Prompt(cast(str, validated_prompt)) - - # validate instructions - if "instructions" in self.validation_map and instructions is not None: - # Runner.validate_instructions - inputs = Inputs( - llm_output=instructions.source, - ) - iteration = Iteration( - call_id=call_log.id, index=attempt_number, inputs=inputs - ) - call_log.iterations.insert(0, iteration) - value, _metadata = await validator_service.async_validate( - value=instructions.source, - metadata=self.metadata, - validator_map=self.validation_map, - iteration=iteration, - disable_tracer=self._disable_tracer, - path="instructions", - ) - validated_instructions = validator_service.post_process_validation( - value, attempt_number, iteration, OutputTypes.STRING - ) - - iteration.outputs.validation_response = validated_instructions - if isinstance(validated_instructions, ReAsk): - raise ValidationError( - f"Instructions validation failed: {validated_instructions}" - ) - elif not validated_instructions or iteration.status == fail_status: - raise ValidationError("Instructions validation failed") - instructions = Instructions(cast(str, validated_instructions)) else: - raise UserFacingException( - ValueError("'prompt' or 'msg_history' must be provided.") - ) + raise UserFacingException(ValueError("'messages' must be provided.")) - return instructions, prompt, msg_history + return messages diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py index 5f6206766..d20e6967a 100644 --- a/guardrails/run/async_stream_runner.py +++ b/guardrails/run/async_stream_runner.py @@ -23,7 +23,6 @@ PromptCallableBase, ) from guardrails.logger import set_scope -from guardrails.prompt import Instructions, Prompt from guardrails.run import StreamRunner from guardrails.run.async_runner import AsyncRunner @@ -35,14 +34,10 @@ async def async_run( prompt_params = prompt_params or {} ( - instructions, - prompt, - msg_history, + messages, output_schema, ) = ( - self.instructions, - self.prompt, - self.msg_history, + self.messages, self.output_schema, ) @@ -51,9 +46,7 @@ async def async_run( output_schema, call_log, api=self.api, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, output=self.output, ) @@ -70,9 +63,7 @@ 
async def async_step( call_log: Call, *, api: Optional[AsyncPromptCallableBase], - instructions: Optional[Instructions], - prompt: Optional[Prompt], - msg_history: Optional[List[Dict]] = None, + messages: Optional[List[Dict]] = None, prompt_params: Optional[Dict] = None, output: Optional[str] = None, ) -> AsyncIterable[ValidationOutcome]: @@ -80,9 +71,7 @@ async def async_step( inputs = Inputs( llm_api=api, llm_output=output, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, num_reasks=self.num_reasks, metadata=self.metadata, @@ -95,28 +84,20 @@ async def async_step( ) set_scope(str(id(iteration))) call_log.iterations.push(iteration) - if output is not None: - instructions = None - prompt = None - msg_history = None + if output: + messages = None else: - instructions, prompt, msg_history = await self.async_prepare( + messages = await self.async_prepare( call_log, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, api=api, attempt_number=index, ) - iteration.inputs.prompt = prompt - iteration.inputs.instructions = instructions - iteration.inputs.msg_history = msg_history + iteration.inputs.messages = messages - llm_response = await self.async_call( - instructions, prompt, msg_history, api, output - ) + llm_response = await self.async_call(messages, api, output) iteration.outputs.llm_response_info = llm_response stream_output = llm_response.async_stream_output if not stream_output: diff --git a/guardrails/run/runner.py b/guardrails/run/runner.py index caab80ea4..4dc6c6f4e 100644 --- a/guardrails/run/runner.py +++ b/guardrails/run/runner.py @@ -1,6 +1,6 @@ import copy from functools import partial -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from guardrails import validator_service @@ -8,15 +8,15 @@ from guardrails.classes.execution.guard_execution_options import GuardExecutionOptions from guardrails.classes.history import Call, Inputs, Iteration, Outputs from guardrails.classes.output_type import OutputTypes -from guardrails.constants import fail_status from guardrails.errors import ValidationError from guardrails.llm_providers import ( AsyncPromptCallableBase, PromptCallableBase, ) from guardrails.logger import set_scope -from guardrails.prompt import Instructions, Prompt -from guardrails.run.utils import msg_history_source, msg_history_string +from guardrails.prompt import Prompt +from guardrails.prompt.messages import Messages +from guardrails.run.utils import messages_string from guardrails.schema.rail_schema import json_schema_to_rail_output from guardrails.schema.validator import schema_validation from guardrails.types import ModelOrListOfModels, ValidatorMap, MessageHistory @@ -29,9 +29,7 @@ prune_extra_keys, ) from guardrails.utils.prompt_utils import ( - preprocess_prompt, prompt_content_for_schema, - prompt_uses_xml, ) from guardrails.actions.reask import NonParseableReAsk, ReAsk, introspect from guardrails.utils.telemetry_utils import trace @@ -61,9 +59,7 @@ class Runner: metadata: Dict[str, Any] # LLM Inputs - prompt: Optional[Prompt] = None - instructions: Optional[Instructions] = None - msg_history: Optional[List[Dict[str, Union[Prompt, str]]]] = None + messages: Optional[List[Dict[str, Union[Prompt, str]]]] = None base_model: Optional[ModelOrListOfModels] exec_options: Optional[GuardExecutionOptions] @@ -77,7 +73,7 @@ class Runner: 
disable_tracer: Optional[bool] = True # QUESTION: Are any of these init args actually necessary for initialization? - # ANSWER: _Maybe_ prompt, instructions, and msg_history for Prompt initialization + # ANSWER: _Maybe_ messages for Prompt initialization # but even that can happen at execution time. # TODO: In versions >=0.6.x, remove this class and just execute a Guard functionally def __init__( @@ -87,9 +83,7 @@ def __init__( num_reasks: int, validation_map: ValidatorMap, *, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, + messages: Optional[List[Dict]] = None, api: Optional[PromptCallableBase] = None, metadata: Optional[Dict[str, Any]] = None, output: Optional[str] = None, @@ -113,34 +107,19 @@ def __init__( xml_output_schema = json_schema_to_rail_output( json_schema=output_schema, validator_map=validation_map ) - if prompt: - self.exec_options.prompt = prompt - self.prompt = Prompt( - prompt, - output_schema=stringified_output_schema, - xml_output_schema=xml_output_schema, - ) - - if instructions: - self.exec_options.instructions = instructions - self.instructions = Instructions( - instructions, - output_schema=stringified_output_schema, - xml_output_schema=xml_output_schema, - ) - if msg_history: - self.exec_options.msg_history = msg_history - msg_history_copy = [] - for msg in msg_history: + if messages: + self.exec_options.messages = messages + messages_copy = [] + for msg in messages: msg_copy = copy.deepcopy(msg) msg_copy["content"] = Prompt( msg_copy["content"], output_schema=stringified_output_schema, xml_output_schema=xml_output_schema, ) - msg_history_copy.append(msg_copy) - self.msg_history = msg_history_copy + messages_copy.append(msg_copy) + self.messages = Messages(source=messages_copy) self.base_model = base_model @@ -171,35 +150,24 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None) -> Call """ prompt_params = prompt_params or {} try: - # Figure out if we need to include instructions in the prompt. - include_instructions = not ( - self.instructions is None and self.msg_history is None - ) - # NOTE: At first glance this seems gratuitous, # but these local variables are reassigned after # calling self.prepare_to_loop ( - instructions, - prompt, - msg_history, + messages, output_schema, ) = ( - self.instructions, - self.prompt, - self.msg_history, + self.messages, self.output_schema, ) - + print("===runner self messages", self.messages) index = 0 for index in range(self.num_reasks + 1): # Run a single step. iteration = self.step( index=index, api=self.api, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, output_schema=output_schema, output=self.output if index == 0 else None, @@ -210,16 +178,13 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None) -> Call if not self.do_loop(index, iteration.reasks): break - # Get new prompt and output schema. - (prompt, instructions, output_schema, msg_history) = ( - self.prepare_to_loop( - iteration.reasks, - output_schema, - parsed_output=iteration.outputs.parsed_output, - validated_output=call_log.validation_response, - prompt_params=prompt_params, - include_instructions=include_instructions, - ) + # Get new messages and output schema. 
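# Illustrative sketch (not part of this patch): the reask loop that Runner.__call__
# drives above, reduced to standalone Python. `run_step` and `build_reask_messages`
# are hypothetical stand-ins for Runner.step and get_reask_setup; only the control
# flow (step, inspect reasks, rebuild messages, loop) mirrors the diff.
from typing import Dict, List


def run_step(messages: List[Dict[str, str]]) -> List[str]:
    # Pretend LLM step: returns outstanding reasks (empty list means success).
    return [] if any("try again" in m["content"] for m in messages) else ["too vague"]


def build_reask_messages(reasks: List[str]) -> List[Dict[str, str]]:
    # Stand-in for get_reask_setup: fresh messages describing what to fix.
    return [{"role": "user", "content": "Please try again: " + ", ".join(reasks)}]


def run(messages: List[Dict[str, str]], num_reasks: int = 1) -> int:
    attempts = 0
    for _ in range(num_reasks + 1):
        attempts += 1
        reasks = run_step(messages)
        if not reasks:
            break
        # Get new messages for the next attempt (mirrors prepare_to_loop).
        messages = build_reask_messages(reasks)
    return attempts


assert run([{"role": "user", "content": "Extract a string from the text"}]) == 2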
+ (messages, output_schema) = self.prepare_to_loop( + iteration.reasks, + output_schema, + parsed_output=iteration.outputs.parsed_output, + validated_output=call_log.validation_response, + prompt_params=prompt_params, ) # Log how many times we reasked @@ -231,7 +196,6 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None) -> Call is_parent=False, # This span has no children has_parent=True, # This span has a parent ) - except UserFacingException as e: # Because Pydantic v1 doesn't respect property setters call_log.exception = e.original_exception @@ -250,20 +214,17 @@ def step( call_log: Call, *, api: Optional[PromptCallableBase], - instructions: Optional[Instructions], - prompt: Optional[Prompt], - msg_history: Optional[List[Dict]] = None, + messages: Optional[List[Dict]] = None, prompt_params: Optional[Dict] = None, output: Optional[str] = None, ) -> Iteration: """Run a full step.""" + print("==== step input messages", messages) prompt_params = prompt_params or {} inputs = Inputs( llm_api=api, llm_output=output, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, num_reasks=self.num_reasks, metadata=self.metadata, @@ -278,27 +239,21 @@ def step( try: # Prepare: run pre-processing, and input validation. - if output is not None: - instructions = None - prompt = None - msg_history = None + if output: + messages = None else: - instructions, prompt, msg_history = self.prepare( + messages = self.prepare( call_log, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, api=api, attempt_number=index, ) - iteration.inputs.instructions = instructions - iteration.inputs.prompt = prompt - iteration.inputs.msg_history = msg_history + iteration.inputs.messages = messages # Call: run the API. 
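# Illustrative sketch (not part of this patch): the call-dispatch order Runner.call
# uses after this change — a pre-supplied output short-circuits the LLM, otherwise
# the messages list is forwarded to the callable as `messages=...` (the real code
# passes `messages.source`). `FakeLLMResponse` and `fake_llm` are hypothetical
# stand-ins for LLMResponse and a PromptCallableBase.
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class FakeLLMResponse:
    output: str


def fake_llm(*, messages: List[Dict[str, str]]) -> FakeLLMResponse:
    return FakeLLMResponse(output="echo: " + messages[-1]["content"])


def call(
    messages: Optional[List[Dict[str, str]]],
    api_fn: Optional[Callable[..., FakeLLMResponse]],
    output: Optional[str] = None,
) -> FakeLLMResponse:
    if output is not None:
        return FakeLLMResponse(output=output)  # reuse a previously captured output
    if api_fn is None:
        raise ValueError("API or output must be provided.")
    if messages:
        return api_fn(messages=messages)  # the only remaining LLM-input path
    return api_fn()


assert call([{"role": "user", "content": "hi"}], fake_llm).output == "echo: hi"
assert call(None, None, output="cached").output == "cached"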
- llm_response = self.call(instructions, prompt, msg_history, api, output) + llm_response = self.call(messages, api, output) iteration.outputs.llm_response_info = llm_response raw_output = llm_response.output @@ -335,10 +290,10 @@ def step( raise e return iteration - def validate_msg_history( - self, call_log: Call, msg_history: MessageHistory, attempt_number: int + def validate_messages( + self, call_log: Call, messages: MessageHistory, attempt_number: int ) -> None: - msg_str = msg_history_string(msg_history) + msg_str = messages_string(messages) inputs = Inputs( llm_output=msg_str, ) @@ -350,189 +305,69 @@ def validate_msg_history( validator_map=self.validation_map, iteration=iteration, disable_tracer=self._disable_tracer, - path="msg_history", + path="messages", ) - validated_msg_history = validator_service.post_process_validation( + validated_messages = validator_service.post_process_validation( value, attempt_number, iteration, OutputTypes.STRING ) - iteration.outputs.validation_response = validated_msg_history - if isinstance(validated_msg_history, ReAsk): + iteration.outputs.validation_response = validated_messages + if isinstance(validated_messages, ReAsk): raise ValidationError( - f"Message history validation failed: " f"{validated_msg_history}" + f"Message validation failed: " f"{validated_messages}" ) - if validated_msg_history != msg_str: - raise ValidationError("Message history validation failed") + if validated_messages != msg_str: + raise ValidationError("Message validation failed") - def prepare_msg_history( + def prepare_messages( self, call_log: Call, - msg_history: MessageHistory, + messages: MessageHistory, prompt_params: Dict, attempt_number: int, ) -> MessageHistory: - formatted_msg_history: MessageHistory = [] + formatted_messages: MessageHistory = [] # Format any variables in the message history with the prompt params. 
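# Illustrative sketch (not part of this patch): what prepare_messages does with
# prompt_params, using string.Template as a standalone stand-in for the Prompt
# objects that the real message contents are wrapped in.
import copy
from string import Template
from typing import Dict, List


def format_messages(
    messages: List[Dict[str, str]], prompt_params: Dict[str, str]
) -> List[Dict[str, str]]:
    formatted = []
    for msg in messages:
        msg_copy = copy.deepcopy(msg)
        # Substitute ${...} placeholders in each message body.
        msg_copy["content"] = Template(msg_copy["content"]).safe_substitute(
            **prompt_params
        )
        formatted.append(msg_copy)
    return formatted


msgs = [{"role": "user", "content": "Respond with a ${greeting_type} greeting."}]
assert format_messages(msgs, {"greeting_type": "friendly"})[0]["content"] == (
    "Respond with a friendly greeting."
)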
- for msg in msg_history: + for msg in messages: msg_copy = copy.deepcopy(msg) msg_copy["content"] = msg_copy["content"].format(**prompt_params) - formatted_msg_history.append(msg_copy) + formatted_messages.append(msg_copy) # validate msg_history - if "msg_history" in self.validation_map: - self.validate_msg_history(call_log, formatted_msg_history, attempt_number) - - return formatted_msg_history - - def validate_prompt(self, call_log: Call, prompt: Prompt, attempt_number: int): - inputs = Inputs( - llm_output=prompt.source, - ) - iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs) - call_log.iterations.insert(0, iteration) - value, _metadata = validator_service.validate( - value=prompt.source, - metadata=self.metadata, - validator_map=self.validation_map, - iteration=iteration, - disable_tracer=self._disable_tracer, - path="prompt", - ) - - validated_prompt = validator_service.post_process_validation( - value, attempt_number, iteration, OutputTypes.STRING - ) - - iteration.outputs.validation_response = validated_prompt - - if isinstance(validated_prompt, ReAsk): - raise ValidationError(f"Prompt validation failed: {validated_prompt}") - elif not validated_prompt or iteration.status == fail_status: - raise ValidationError("Prompt validation failed") - return Prompt(cast(str, validated_prompt)) - - def validate_instructions( - self, call_log: Call, instructions: Instructions, attempt_number: int - ): - inputs = Inputs( - llm_output=instructions.source, - ) - iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs) - call_log.iterations.insert(0, iteration) - value, _metadata = validator_service.validate( - value=instructions.source, - metadata=self.metadata, - validator_map=self.validation_map, - iteration=iteration, - disable_tracer=self._disable_tracer, - path="instructions", - ) - validated_instructions = validator_service.post_process_validation( - value, attempt_number, iteration, OutputTypes.STRING - ) + if "messages" in self.validation_map: + self.validate_messages(call_log, formatted_messages, attempt_number) - iteration.outputs.validation_response = validated_instructions - if isinstance(validated_instructions, ReAsk): - raise ValidationError( - f"Instructions validation failed: {validated_instructions}" - ) - elif not validated_instructions or iteration.status == fail_status: - raise ValidationError("Instructions validation failed") - return Instructions(cast(str, validated_instructions)) - - def prepare_prompt( - self, - call_log: Call, - instructions: Optional[Instructions], - prompt: Prompt, - prompt_params: Dict, - api: Union[PromptCallableBase, AsyncPromptCallableBase], - attempt_number: int, - ): - use_xml = prompt_uses_xml(self.prompt._source) if self.prompt else False - prompt = prompt.format(**prompt_params) - - # TODO(shreya): should there be any difference - # to parsing params for prompt? 
- if instructions is not None and isinstance(instructions, Instructions): - instructions = instructions.format(**prompt_params) - - instructions, prompt = preprocess_prompt( - prompt_callable=api, - instructions=instructions, - prompt=prompt, - output_type=self.output_type, - use_xml=use_xml, - ) - - # validate prompt - if "prompt" in self.validation_map and prompt is not None: - prompt = self.validate_prompt(call_log, prompt, attempt_number) - - # validate instructions - if "instructions" in self.validation_map and instructions is not None: - instructions = self.validate_instructions( - call_log, instructions, attempt_number - ) - - return instructions, prompt + return formatted_messages def prepare( self, call_log: Call, attempt_number: int, *, - instructions: Optional[Instructions], - prompt: Optional[Prompt], - msg_history: Optional[MessageHistory], + messages: Optional[MessageHistory], prompt_params: Optional[Dict] = None, api: Optional[Union[PromptCallableBase, AsyncPromptCallableBase]], - ) -> Tuple[Optional[Instructions], Optional[Prompt], Optional[MessageHistory]]: + ) -> Tuple[Optional[MessageHistory]]: """Prepare by running pre-processing and input validation. Returns: - The instructions, prompt, and message history. + The messages. """ prompt_params = prompt_params or {} if api is None: raise UserFacingException(ValueError("API must be provided.")) - has_prompt_validation = "prompt" in self.validation_map - has_instructions_validation = "instructions" in self.validation_map - has_msg_history_validation = "msg_history" in self.validation_map - if msg_history: - if has_prompt_validation or has_instructions_validation: - raise UserFacingException( - ValueError( - "Prompt and instructions validation are " - "not supported when using message history." - ) - ) - prompt, instructions = None, None - msg_history = self.prepare_msg_history( - call_log, msg_history, prompt_params, attempt_number - ) - elif prompt is not None: - if has_msg_history_validation: - raise UserFacingException( - ValueError( - "Message history validation is " - "not supported when using prompt/instructions." - ) - ) - msg_history = None - instructions, prompt = self.prepare_prompt( - call_log, instructions, prompt, prompt_params, api, attempt_number + if messages: + messages = self.prepare_messages( + call_log, messages, prompt_params, attempt_number ) - return instructions, prompt, msg_history + return messages @trace(name="call") def call( self, - instructions: Optional[Instructions], - prompt: Optional[Prompt], - msg_history: Optional[MessageHistory], + messages: Optional[MessageHistory], api: Optional[PromptCallableBase], output: Optional[str] = None, ) -> LLMResponse: @@ -542,7 +377,7 @@ def call( 2. Convert the response string to a dict, 3. Log the output """ - + print("====messages", messages) # If the API supports a base model, pass it in. 
api_fn = api if api is not None: @@ -554,12 +389,8 @@ def call( llm_response = LLMResponse(output=output) elif api_fn is None: raise ValueError("API or output must be provided.") - elif msg_history: - llm_response = api_fn(msg_history=msg_history_source(msg_history)) - elif prompt and instructions: - llm_response = api_fn(prompt.source, instructions=instructions.source) - elif prompt: - llm_response = api_fn(prompt.source) + elif messages: + llm_response = api_fn(messages=messages.source) else: llm_response = api_fn() @@ -635,16 +466,10 @@ def prepare_to_loop( parsed_output: Optional[Union[str, List, Dict, ReAsk]] = None, validated_output: Optional[Union[str, List, Dict, ReAsk]] = None, prompt_params: Optional[Dict] = None, - include_instructions: bool = False, - ) -> Tuple[ - Prompt, - Optional[Instructions], - Dict[str, Any], - Optional[List[Dict]], - ]: + ) -> Tuple[Dict[str, Any], Optional[List[Dict]]]: """Prepare to loop again.""" prompt_params = prompt_params or {} - output_schema, prompt, instructions = get_reask_setup( + output_schema, messages = get_reask_setup( output_type=self.output_type, output_schema=output_schema, validation_map=self.validation_map, @@ -655,8 +480,5 @@ def prepare_to_loop( prompt_params=prompt_params, exec_options=self.exec_options, ) - if not include_instructions: - instructions = None - # todo add messages support - msg_history = None - return prompt, instructions, output_schema, msg_history + + return output_schema, messages diff --git a/guardrails/run/stream_runner.py b/guardrails/run/stream_runner.py index d53f3dc2e..aa412a458 100644 --- a/guardrails/run/stream_runner.py +++ b/guardrails/run/stream_runner.py @@ -9,7 +9,6 @@ OpenAIChatCallable, PromptCallableBase, ) -from guardrails.prompt import Instructions, Prompt from guardrails.run.runner import Runner from guardrails.utils.parsing_utils import ( coerce_types, @@ -41,32 +40,20 @@ def __call__( Returns: The Call log for this run. """ - # This is only used during ReAsks and ReAsks - # are not yet supported for streaming. - # Figure out if we need to include instructions in the prompt. - # include_instructions = not ( - # self.instructions is None and self.msg_history is None - # ) prompt_params = prompt_params or {} ( - instructions, - prompt, - msg_history, + messages, output_schema, ) = ( - self.instructions, - self.prompt, - self.msg_history, + self.messages, self.output_schema, ) return self.step( index=0, api=self.api, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, output_schema=output_schema, output=self.output, @@ -78,9 +65,7 @@ def step( self, index: int, api: Optional[PromptCallableBase], - instructions: Optional[Instructions], - prompt: Optional[Prompt], - msg_history: Optional[List[Dict]], + messages: Optional[List[Dict]], prompt_params: Dict, output_schema: Dict[str, Any], call_log: Call, @@ -90,9 +75,7 @@ def step( inputs = Inputs( llm_api=api, llm_output=output, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, num_reasks=self.num_reasks, metadata=self.metadata, @@ -106,27 +89,21 @@ def step( call_log.iterations.push(iteration) # Prepare: run pre-processing, and input validation. 
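# Illustrative, standalone equivalent (not part of this patch) of the
# messages_string helper added to guardrails/run/utils.py further down in this
# diff: input validators registered on the "messages" path see the concatenated
# message contents as a single string.
from typing import Dict, List


def flatten_messages(messages: List[Dict[str, object]]) -> str:
    text = ""
    for msg in messages:
        content = msg["content"]
        # The real helper unwraps Prompt objects via `.source`; plain strings pass through.
        text += content.source if hasattr(content, "source") else str(content)
    return text


assert flatten_messages(
    [
        {"role": "system", "content": "You are a greeting bot."},
        {"role": "user", "content": " Respond with a friendly greeting."},
    ]
) == "You are a greeting bot. Respond with a friendly greeting."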
- if output is not None: - instructions = None - prompt = None - msg_history = None + if output: + messages = None else: - instructions, prompt, msg_history = self.prepare( + messages = self.prepare( call_log, index, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, api=api, ) - iteration.inputs.prompt = prompt - iteration.inputs.instructions = instructions - iteration.inputs.msg_history = msg_history + iteration.inputs.messages = messages # Call: run the API that returns a generator wrapped in LLMResponse - llm_response = self.call(instructions, prompt, msg_history, api, output) + llm_response = self.call(messages, api, output) iteration.outputs.llm_response_info = llm_response diff --git a/guardrails/run/utils.py b/guardrails/run/utils.py index cdd9de4d2..e6979b4c6 100644 --- a/guardrails/run/utils.py +++ b/guardrails/run/utils.py @@ -2,6 +2,7 @@ from typing import Dict, cast from guardrails.prompt.prompt import Prompt +from guardrails.prompt.messages import Messages from guardrails.types.inputs import MessageHistory @@ -29,3 +30,15 @@ def msg_history_string(msg_history: MessageHistory) -> str: ) msg_history_copy += content return msg_history_copy + + +def messages_string(messages: Messages) -> str: + messages_copy = "" + for msg in messages: + content = ( + msg["content"].source + if isinstance(msg["content"], Prompt) + else msg["content"] + ) + messages_copy += content + return messages_copy diff --git a/guardrails/schema/rail_schema.py b/guardrails/schema/rail_schema.py index 2aaa04ec4..ab8edbbe7 100644 --- a/guardrails/schema/rail_schema.py +++ b/guardrails/schema/rail_schema.py @@ -1,5 +1,4 @@ import jsonref -import warnings from dataclasses import dataclass from string import Template from typing import Any, Callable, Dict, List, Optional, Tuple, cast @@ -21,14 +20,7 @@ ### RAIL to JSON Schema ### -STRING_TAGS = [ - "instructions", - "prompt", - "reask_instructions", - "reask_prompt", - "messages", - "reask_messages", -] +STRING_TAGS = ["messages", "reask_messages"] def parse_on_fail_handlers(element: _Element) -> Dict[str, OnFailAction]: @@ -392,49 +384,6 @@ def rail_string_to_schema(rail_string: str) -> ProcessedSchema: ' "string", "object", or "list"' ) - # Parse instructions for the LLM. These are optional but if given, - # LLMs can use them to improve their output. Commonly these are - # prepended to the prompt. - instructions_tag = rail_xml.find("instructions") - if instructions_tag is not None: - parse_element(instructions_tag, processed_schema, "instructions") - processed_schema.exec_opts.instructions = instructions_tag.text - warnings.warn( - "The instructions tag has been deprecated" - " in favor of messages. Please use messages instead.", - DeprecationWarning, - ) - - # Load - prompt_tag = rail_xml.find("prompt") - if prompt_tag is not None: - parse_element(prompt_tag, processed_schema, "prompt") - processed_schema.exec_opts.prompt = prompt_tag.text - warnings.warn( - "The prompt tag has been deprecated" - " in favor of messages. Please use messages instead.", - DeprecationWarning, - ) - - # If reasking prompt and instructions are provided, add them to the schema. - reask_prompt = rail_xml.find("reask_prompt") - if reask_prompt is not None: - processed_schema.exec_opts.reask_prompt = reask_prompt.text - warnings.warn( - "The reask_prompt tag has been deprecated" - " in favor of reask_messages. 
Please use reask_messages instead.", - DeprecationWarning, - ) - - reask_instructions = rail_xml.find("reask_instructions") - if reask_instructions is not None: - processed_schema.exec_opts.reask_instructions = reask_instructions.text - warnings.warn( - "The reask_instructions tag has been deprecated" - " in favor of reask_messages. Please use reask_messages instead.", - DeprecationWarning, - ) - messages = rail_xml.find("messages") if messages is not None: extracted_messages = [] @@ -455,7 +404,7 @@ def rail_string_to_schema(rail_string: str) -> ProcessedSchema: role = message.attrib.get("role") content = message.text extracted_reask_messages.append({"role": role, "content": content}) - processed_schema.exec_opts.messages = extracted_reask_messages + processed_schema.exec_opts.reask_messages = extracted_reask_messages return processed_schema diff --git a/tests/integration_tests/test_streaming.py b/tests/integration_tests/test_streaming.py index a74c58fe5..38da58c2c 100644 --- a/tests/integration_tests/test_streaming.py +++ b/tests/integration_tests/test_streaming.py @@ -199,6 +199,8 @@ class MinSentenceLengthNoOp(BaseModel): ${gr.complete_json_suffix} """ +MESSAGES=[{"role": "user", "content": PROMPT}] + JSON_LLM_CHUNKS = [ '{"statement":', ' "I am DOING', @@ -212,19 +214,19 @@ class MinSentenceLengthNoOp(BaseModel): "guard, expected_validated_output", [ ( - gd.Guard.from_pydantic(output_class=LowerCaseNoop, prompt=PROMPT), + gd.Guard.from_pydantic(output_class=LowerCaseNoop, messages=MESSAGES), expected_noop_output, ), ( - gd.Guard.from_pydantic(output_class=LowerCaseFix, prompt=PROMPT), + gd.Guard.from_pydantic(output_class=LowerCaseFix, messages=MESSAGES), expected_fix_output, ), ( - gd.Guard.from_pydantic(output_class=LowerCaseFilter, prompt=PROMPT), + gd.Guard.from_pydantic(output_class=LowerCaseFilter, messages=MESSAGES), expected_filter_refrain_output, ), ( - gd.Guard.from_pydantic(output_class=LowerCaseRefrain, prompt=PROMPT), + gd.Guard.from_pydantic(output_class=LowerCaseRefrain, messages=MESSAGES), expected_filter_refrain_output, ), ], @@ -268,19 +270,19 @@ def test_streaming_with_openai_callable( "guard, expected_validated_output", [ ( - gd.Guard.from_pydantic(output_class=LowerCaseNoop, prompt=PROMPT), + gd.Guard.from_pydantic(output_class=LowerCaseNoop, messages=MESSAGES), expected_noop_output, ), ( - gd.Guard.from_pydantic(output_class=LowerCaseFix, prompt=PROMPT), + gd.Guard.from_pydantic(output_class=LowerCaseFix, messages=MESSAGES), expected_fix_output, ), ( - gd.Guard.from_pydantic(output_class=LowerCaseFilter, prompt=PROMPT), + gd.Guard.from_pydantic(output_class=LowerCaseFilter, messages=MESSAGES), expected_filter_refrain_output, ), ( - gd.Guard.from_pydantic(output_class=LowerCaseRefrain, prompt=PROMPT), + gd.Guard.from_pydantic(output_class=LowerCaseRefrain, messages=MESSAGES), expected_filter_refrain_output, ), ], @@ -345,7 +347,7 @@ def test_streaming_with_openai_chat_callable( validators=[ MinSentenceLengthValidator(26, 30, on_fail=OnFailAction.NOOP) ], - prompt=STR_PROMPT, + messages=MESSAGES, ), # each value is a tuple # first is expected text inside span diff --git a/tests/unit_tests/actions/test_reask.py b/tests/unit_tests/actions/test_reask.py index ccf6e77d7..5e629551a 100644 --- a/tests/unit_tests/actions/test_reask.py +++ b/tests/unit_tests/actions/test_reask.py @@ -542,13 +542,12 @@ def test_get_reask_prompt( output_schema = processed_schema.json_schema exec_options = GuardExecutionOptions( # Use an XML constant to make existing test cases pass - 
prompt="${gr.complete_xml_suffix_v3}" + messages=[{"role": "system", "content": "${gr.complete_xml_suffix_v3}"}] ) ( reask_schema, - reask_prompt, - reask_instructions, + reask_messages, ) = get_reask_setup( output_type, output_schema, @@ -569,9 +568,8 @@ def test_get_reask_prompt( # json.dumps(json_example, indent=2), ) - assert reask_prompt.source == expected_prompt - - assert reask_instructions.source == expected_instructions + assert str(reask_messages.source[0]["content"]) == expected_instructions + assert str(reask_messages.source[1]["content"]) == expected_prompt ### FIXME: Implement once Field Level ReAsk is implemented w/ JSON schema ### diff --git a/tests/unit_tests/classes/history/test_call.py b/tests/unit_tests/classes/history/test_call.py index f1f2eb034..a9ce9706e 100644 --- a/tests/unit_tests/classes/history/test_call.py +++ b/tests/unit_tests/classes/history/test_call.py @@ -8,6 +8,7 @@ from guardrails.llm_providers import ArbitraryCallable from guardrails.prompt.instructions import Instructions from guardrails.prompt.prompt import Prompt +from guardrails.prompt.messages import Messages from guardrails.classes.llm.llm_response import LLMResponse from guardrails.classes.validation.validator_logs import ValidatorLogs from guardrails.actions.reask import ReAsk @@ -19,13 +20,11 @@ def test_empty_initialization(): assert call.iterations == Stack() assert call.inputs == CallInputs() - assert call.prompt is None assert call.prompt_params is None - assert call.compiled_prompt is None assert call.reask_prompts == Stack() - assert call.instructions is None - assert call.compiled_instructions is None - assert call.reask_instructions == Stack() + assert call.messages is None + assert call.compiled_messages is None + assert len(call.reask_messages) == 0 assert call.logs == Stack() assert call.tokens_consumed is None assert call.prompt_tokens_consumed is None @@ -69,8 +68,12 @@ def custom_llm(): # First Iteration Inputs iter_llm_api = ArbitraryCallable(llm_api=llm_api) llm_output = "Hello there!" - instructions = Instructions(source=instructions) - iter_prompt = Prompt(source=prompt) + + messages = Messages(source=[ + {"role": "system", "content": instructions}, + {"role": "user", "content": prompt}, + ]) + num_reasks = 0 metadata = {"some_meta_data": "doesn't actually matter"} full_schema_reask = False @@ -78,8 +81,7 @@ def custom_llm(): inputs = Inputs( llm_api=iter_llm_api, llm_output=llm_output, - instructions=instructions, - prompt=iter_prompt, + messages=messages, prompt_params=prompt_params, num_reasks=num_reasks, metadata=metadata, @@ -121,13 +123,14 @@ def custom_llm(): call_id="mock-call", index=0, inputs=inputs, outputs=first_outputs ) - second_iter_prompt = Prompt(source="That wasn't quite right. Try again.") + second_iter_messages = Messages(source=[ + {"role":"system", "content":"That wasn't quite right. Try again."} + ]) second_inputs = Inputs( llm_api=iter_llm_api, llm_output=llm_output, - instructions=instructions, - prompt=second_iter_prompt, + messages=second_iter_messages, num_reasks=num_reasks, metadata=metadata, full_schema_reask=full_schema_reask, @@ -171,11 +174,10 @@ def custom_llm(): assert call.prompt == prompt assert call.prompt_params == prompt_params - assert call.compiled_prompt == "Respond with a friendly greeting." 
- assert call.reask_prompts == Stack(second_iter_prompt.source) - assert call.instructions == instructions.source - assert call.compiled_instructions == instructions.source - assert call.reask_instructions == Stack(instructions.source) + # TODO FIX this + # assert call.messages == messages.source + assert call.compiled_messages[1]["content"] == "Respond with a friendly greeting." + assert call.reask_messages == Stack(second_iter_messages.source) # TODO: Test this in the integration tests assert call.logs == [] diff --git a/tests/unit_tests/classes/history/test_call_inputs.py b/tests/unit_tests/classes/history/test_call_inputs.py index 74191683f..407f242b4 100644 --- a/tests/unit_tests/classes/history/test_call_inputs.py +++ b/tests/unit_tests/classes/history/test_call_inputs.py @@ -6,14 +6,12 @@ def test_empty_initialization(): # Overrides and additional properties assert call_inputs.llm_api is None - assert call_inputs.prompt is None - assert call_inputs.instructions is None assert call_inputs.args == [] assert call_inputs.kwargs == {} # Inherited properties assert call_inputs.llm_output is None - assert call_inputs.msg_history is None + assert call_inputs.messages is None assert call_inputs.prompt_params is None assert call_inputs.num_reasks is None assert call_inputs.metadata is None diff --git a/tests/unit_tests/classes/history/test_inputs.py b/tests/unit_tests/classes/history/test_inputs.py index 4ef221c48..830f11e01 100644 --- a/tests/unit_tests/classes/history/test_inputs.py +++ b/tests/unit_tests/classes/history/test_inputs.py @@ -1,7 +1,6 @@ from guardrails.classes.history.inputs import Inputs from guardrails.llm_providers import OpenAICallable -from guardrails.prompt.instructions import Instructions -from guardrails.prompt.prompt import Prompt +from guardrails.prompt.messages import Messages # Guard against regressions in pydantic BaseModel @@ -10,9 +9,7 @@ def test_empty_initialization(): assert inputs.llm_api is None assert inputs.llm_output is None - assert inputs.instructions is None - assert inputs.prompt is None - assert inputs.msg_history is None + assert inputs.messages is None assert inputs.prompt_params is None assert inputs.num_reasks is None assert inputs.metadata is None @@ -22,11 +19,12 @@ def test_empty_initialization(): def test_non_empty_initialization(): llm_api = OpenAICallable(text="Respond with a greeting.") llm_output = "Hello there!" 
- instructions = Instructions(source="You are a greeting bot.") - prompt = Prompt(source="Respond with a ${greeting_type} greeting.") - msg_history = [ - {"some_key": "doesn't actually matter because this isn't that strongly typed"} - ] + messages = Messages( + source=[ + {"role": "system", "content": "You are a greeting bot."}, + {"role": "user", "content": "Respond with a ${greeting_type} greeting."}, + ] + ) prompt_params = {"greeting_type": "friendly"} num_reasks = 0 metadata = {"some_meta_data": "doesn't actually matter"} @@ -35,9 +33,7 @@ def test_non_empty_initialization(): inputs = Inputs( llm_api=llm_api, llm_output=llm_output, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, num_reasks=num_reasks, metadata=metadata, @@ -48,12 +44,8 @@ def test_non_empty_initialization(): assert inputs.llm_api == llm_api assert inputs.llm_output is not None assert inputs.llm_output == llm_output - assert inputs.instructions is not None - assert inputs.instructions == instructions - assert inputs.prompt is not None - assert inputs.prompt == prompt - assert inputs.msg_history is not None - assert inputs.msg_history == msg_history + assert inputs.messages is not None + assert inputs.messages == messages assert inputs.prompt_params is not None assert inputs.prompt_params == prompt_params assert inputs.num_reasks is not None diff --git a/tests/unit_tests/classes/history/test_iteration.py b/tests/unit_tests/classes/history/test_iteration.py index 38bc05d73..802e321a1 100644 --- a/tests/unit_tests/classes/history/test_iteration.py +++ b/tests/unit_tests/classes/history/test_iteration.py @@ -6,6 +6,7 @@ from guardrails.llm_providers import OpenAICallable from guardrails.prompt.instructions import Instructions from guardrails.prompt.prompt import Prompt +from guardrails.prompt.messages import Messages from guardrails.classes.llm.llm_response import LLMResponse from guardrails.classes.validation.validator_logs import ValidatorLogs from guardrails.actions.reask import FieldReAsk @@ -40,11 +41,10 @@ def test_non_empty_initialization(): # Inputs llm_api = OpenAICallable(text="Respond with a greeting.") llm_output = "Hello there!" 
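# Hedged sketch (assumes the patched guardrails from this diff is installed):
# history Inputs now carry a single Messages object in place of the separate
# instructions/prompt/msg_history fields, exactly as the updated tests construct.
from guardrails.classes.history.inputs import Inputs
from guardrails.prompt.messages import Messages

history_messages = Messages(
    source=[
        {"role": "system", "content": "You are a greeting bot."},
        {"role": "user", "content": "Respond with a ${greeting_type} greeting."},
    ]
)
history_inputs = Inputs(
    messages=history_messages,
    prompt_params={"greeting_type": "friendly"},
)
assert history_inputs.messages == history_messages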
- instructions = Instructions(source="You are a greeting bot.") - prompt = Prompt(source="Respond with a ${greeting_type} greeting.") - msg_history = [ - {"some_key": "doesn't actually matter because this isn't that strongly typed"} - ] + messages = Messages(source=[ + {"role": "system", "content": "You are a greeting bot."}, + {"role": "user", "content": "Respond with a ${greeting_type} greeting."} + ]) prompt_params = {"greeting_type": "friendly"} num_reasks = 0 metadata = {"some_meta_data": "doesn't actually matter"} @@ -53,9 +53,7 @@ def test_non_empty_initialization(): inputs = Inputs( llm_api=llm_api, llm_output=llm_output, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, + messages=messages, prompt_params=prompt_params, num_reasks=num_reasks, metadata=metadata, diff --git a/tests/unit_tests/test_guard.py b/tests/unit_tests/test_guard.py index f3e951929..7b2916fab 100644 --- a/tests/unit_tests/test_guard.py +++ b/tests/unit_tests/test_guard.py @@ -521,9 +521,9 @@ def test_validate(): # Should still only use the output validators to validate the output guard: Guard = ( Guard() - .use(OneLine, on="prompt") - .use(LowerCase, on="instructions") - .use(UpperCase, on="msg_history") + .use(OneLine, on="messages") + .use(LowerCase, on="messages") + .use(UpperCase, on="messages") .use(LowerCase, on="output", on_fail=OnFailAction.FIX) .use(TwoWords, on="output") .use(ValidLength, 0, 12, on="output") @@ -547,9 +547,7 @@ def test_validate(): def test_use_and_use_many(): guard: Guard = ( Guard() - .use_many(OneLine(), LowerCase(), on="prompt") - .use(UpperCase, on="instructions") - .use(LowerCase, on="msg_history") + .use_many(OneLine(), LowerCase(), on="messages") .use_many( TwoWords(on_fail=OnFailAction.REASK), ValidLength(0, 12, on_fail=OnFailAction.REFRAIN), @@ -557,20 +555,12 @@ def test_use_and_use_many(): ) ) - # Check schemas for prompt, instructions and msg_history validators - prompt_validators = guard._validator_map.get("prompt", []) + # Check schemas for messages validators + prompt_validators = guard._validator_map.get("messages", []) assert len(prompt_validators) == 2 assert prompt_validators[0].__class__.__name__ == "OneLine" assert prompt_validators[1].__class__.__name__ == "LowerCase" - instructions_validators = guard._validator_map.get("instructions", []) - assert len(instructions_validators) == 1 - assert instructions_validators[0].__class__.__name__ == "UpperCase" - - msg_history_validators = guard._validator_map.get("msg_history", []) - assert len(msg_history_validators) == 1 - assert msg_history_validators[0].__class__.__name__ == "LowerCase" - # Check guard for validators assert len(guard._validators) == 6 @@ -590,9 +580,8 @@ def test_use_and_use_many(): with pytest.warns(UserWarning): guard: Guard = ( Guard() - .use_many(OneLine(), LowerCase(), on="prompt") - .use(UpperCase, on="instructions") - .use(LowerCase, on="msg_history") + .use_many(OneLine(), LowerCase(), on="messages") + .use(UpperCase, on="messages") .use_many( TwoWords(on_fail=OnFailAction.REASK), ValidLength(0, 12, on_fail=OnFailAction.REFRAIN), diff --git a/tests/unit_tests/test_prompt.py b/tests/unit_tests/test_prompt.py index 7c933ecea..805563c93 100644 --- a/tests/unit_tests/test_prompt.py +++ b/tests/unit_tests/test_prompt.py @@ -6,8 +6,8 @@ from pydantic import BaseModel, Field import guardrails as gd -from guardrails.prompt.instructions import Instructions from guardrails.prompt.prompt import Prompt +from guardrails.prompt.messages import Messages from guardrails.utils.constants 
import constants from guardrails.utils.prompt_utils import prompt_content_for_schema @@ -15,28 +15,34 @@ PROMPT = "Extract a string from the text" -REASK_PROMPT = """ +REASK_MESSAGES = [ + { + "role": "user", + "content": """ Please try that again, extract a string from the text ${xml_output_schema} ${previous_response} -""" +""", + } +] SIMPLE_RAIL_SPEC = f""" - + + {INSTRUCTIONS} - - - + + {PROMPT} - + + """ @@ -46,17 +52,18 @@ - + + ${user_instructions} - - - + + ${user_prompt} - + + """ @@ -66,18 +73,19 @@ - -You are a helpful bot, who answers only with valid JSON + + - +You are a helpful bot, who answers only with valid JSON - + + Extract a string from the text ${gr.complete_json_suffix_v2} - + """ @@ -122,6 +130,34 @@ """ +RAIL_WITH_REASK_MESSAGES = """ + + + + + + + + +You are a helpful bot, who answers only with valid JSON + + +${gr.complete_json_suffix_v2} + + + + + +Please try that again, extract a string from the text +${xml_output_schema} +${previous_response} + + + + +""" + + RAIL_WITH_REASK_INSTRUCTIONS = """ @@ -156,47 +192,59 @@ def test_parse_prompt(): guard = gd.Guard.from_rail_string(SIMPLE_RAIL_SPEC) # Strip both, raw and parsed, to be safe - instructions = Instructions(guard._exec_opts.instructions) - assert instructions.format().source.strip() == INSTRUCTIONS.strip() - prompt = Prompt(guard._exec_opts.prompt) - assert prompt.format().source.strip() == PROMPT.strip() + messages = Messages(source=guard._exec_opts.messages) + assert messages.format().source[0]["content"].strip() == INSTRUCTIONS.strip() + assert messages.format().source[1]["content"].strip() == PROMPT.strip() -def test_instructions_with_params(): - """Test a guard with instruction parameters.""" +def test_messages_with_params(): + """Test a guard with message parameters.""" guard = gd.Guard.from_rail_string(RAIL_WITH_PARAMS) user_instructions = "A useful system message." user_prompt = "A useful prompt." 
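# Hedged sketch (assumes the patched guardrails Messages class exercised in these
# tests): format() substitutes ${...} parameters across every message, and .source
# exposes the role/content dicts by position, which is what the assertions below rely on.
from guardrails.prompt.messages import Messages

param_messages = Messages(
    source=[
        {"role": "system", "content": "${user_instructions}"},
        {"role": "user", "content": "${user_prompt}"},
    ]
)
formatted = param_messages.format(
    user_instructions="A useful system message.",
    user_prompt="A useful prompt.",
)
assert formatted.source[0]["content"].strip() == "A useful system message."
assert formatted.source[1]["content"].strip() == "A useful prompt."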
+ messages = Messages(guard._exec_opts.messages) - instructions = Instructions(guard._exec_opts.instructions) assert ( - instructions.format(user_instructions=user_instructions).source.strip() + messages.format( + user_instructions=user_instructions, + user_prompt=user_prompt, + ) + .source[1]["content"] + .strip() + == user_prompt.strip() + ) + + assert ( + messages.format( + user_instructions=user_instructions, + user_prompt=user_prompt, + ) + .source[0]["content"] + .strip() == user_instructions.strip() ) - prompt = Prompt(guard._exec_opts.prompt) - assert prompt.format(user_prompt=user_prompt).source.strip() == user_prompt.strip() @pytest.mark.parametrize( "rail,var_names", [ (SIMPLE_RAIL_SPEC, []), - (RAIL_WITH_PARAMS, ["user_prompt"]), + (RAIL_WITH_PARAMS, ["user_instructions", "user_prompt"]), ], ) def test_variable_names(rail, var_names): """Test extracting variable names from a prompt.""" guard = gd.Guard.from_rail_string(rail) - prompt = Prompt(guard._exec_opts.prompt) + messages = Messages(guard._exec_opts.messages) - assert prompt.variable_names == var_names + assert messages.variable_names == var_names -def test_format_instructions(): - """Test extracting format instructions from a prompt.""" - guard = gd.Guard.from_rail_string(RAIL_WITH_FORMAT_INSTRUCTIONS) +def test_format_messages(): + """Test extracting format messages from a prompt.""" + guard = gd.Guard.from_rail_string(RAIL_WITH_REASK_MESSAGES) output_schema = prompt_content_for_schema( guard._output_type, @@ -206,22 +254,22 @@ def test_format_instructions(): ) expected_instructions = ( - Template(constants["complete_json_suffix_v2"]) - .safe_substitute(output_schema=output_schema) - .rstrip() + Template(constants["complete_json_suffix_v2"]).safe_substitute( + output_schema=output_schema + ) + ).rstrip() + + messages = Messages( + source=guard._exec_opts.messages, + output_schema=output_schema, ) - prompt = Prompt(guard._exec_opts.prompt, output_schema=output_schema) - assert prompt.format_instructions.rstrip() == expected_instructions + assert messages.source[1]["content"].rstrip() == expected_instructions -def test_reask_prompt(): - guard = gd.Guard.from_rail_string(RAIL_WITH_REASK_PROMPT) - assert guard._exec_opts.reask_prompt == REASK_PROMPT - -def test_reask_instructions(): - guard = gd.Guard.from_rail_string(RAIL_WITH_REASK_INSTRUCTIONS) - assert guard._exec_opts.reask_instructions == INSTRUCTIONS +def test_reask_messages(): + guard = gd.Guard.from_rail_string(RAIL_WITH_REASK_MESSAGES) + assert guard._exec_opts.reask_messages == REASK_MESSAGES @pytest.mark.parametrize( @@ -246,10 +294,14 @@ class TestResponse(BaseModel): def test_gr_prefixed_prompt_item_passes(): # From pydantic: - prompt = """Give me a response to ${grade}""" - - guard = gd.Guard.from_pydantic(output_class=TestResponse, prompt=prompt) - prompt = Prompt(guard._exec_opts.prompt) + messages = [ + { + "role": "user", + "content": "Give me a response to ${grade}", + } + ] + guard = gd.Guard.from_pydantic(output_class=TestResponse, messages=messages) + prompt = Messages(source=guard._exec_opts.messages) assert len(prompt.variable_names) == 1 diff --git a/tests/unit_tests/test_validator_base.py b/tests/unit_tests/test_validator_base.py index 1b30aaee2..190265319 100644 --- a/tests/unit_tests/test_validator_base.py +++ b/tests/unit_tests/test_validator_base.py @@ -295,7 +295,15 @@ class Pet(BaseModel): pet_type: str = Field(description="Species of pet", validators=[validator]) name: str = Field(description="a unique pet name") - guard = 
Guard.from_pydantic(output_class=Pet, prompt=prompt) + guard = Guard.from_pydantic( + output_class=Pet, + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) if isinstance(expected_result, type) and issubclass(expected_result, Exception): with pytest.raises(ValidationError) as excinfo: guard.parse(output, num_reasks=0) @@ -316,38 +324,14 @@ def test_input_validation_fix(mocker): def mock_llm_api(*args, **kwargs): return json.dumps({"name": "Fluffy"}) - # fix returns an amended value for prompt/instructions validation, - guard = Guard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=OnFailAction.FIX), on="prompt") - - guard( - mock_llm_api, - prompt="What kind of pet should I get?", - ) - assert ( - guard.history.first.iterations.first.outputs.validation_response == "What kind" - ) - guard = Guard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=OnFailAction.FIX), on="instructions") - - guard( - mock_llm_api, - prompt="What kind of pet should I get and what should I name it?", - instructions="But really, what kind of pet should I get?", - ) - assert ( - guard.history.first.iterations.first.outputs.validation_response - == "But really," - ) - - # but raises for msg_history validation + # raises for messages validation guard = Guard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=OnFailAction.FIX), on="msg_history") + guard.use(TwoWords(on_fail=OnFailAction.FIX), on="messages") with pytest.raises(ValidationError) as excinfo: guard( mock_llm_api, - msg_history=[ + messages=[ { "role": "user", "content": "What kind of pet should I get?", @@ -415,7 +399,12 @@ async def mock_llm_api(*args, **kwargs): await guard( mock_llm_api, - prompt="What kind of pet should I get?", + messages=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], ) assert ( guard.history.first.iterations.first.outputs.validation_response == "What kind" @@ -434,14 +423,14 @@ async def mock_llm_api(*args, **kwargs): == "But really," ) - # but raises for msg_history validation + # but raises for messages validation guard = AsyncGuard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=OnFailAction.FIX), on="msg_history") + guard.use(TwoWords(on_fail=OnFailAction.FIX), on="messages") with pytest.raises(ValidationError) as excinfo: await guard( mock_llm_api, - msg_history=[ + messages=[ { "role": "user", "content": "What kind of pet should I get?", @@ -561,7 +550,12 @@ def custom_llm(*args, **kwargs): with pytest.raises(ValidationError) as excinfo: guard( custom_llm, - prompt="What kind of pet should I get?", + messages=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], ) assert str(excinfo.value) == structured_prompt_error assert isinstance(guard.history.last.exception, ValidationError) @@ -584,12 +578,12 @@ def custom_llm(*args, **kwargs): # With Msg History Validation guard = Guard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=on_fail), on="msg_history") + guard.use(TwoWords(on_fail=on_fail), on="messages") with pytest.raises(ValidationError) as excinfo: guard( custom_llm, - msg_history=[ + messages=[ { "role": "user", "content": "What kind of pet should I get?", @@ -713,41 +707,14 @@ async def custom_llm(*args, **kwargs): return_value=custom_llm, ) - # with_prompt_validation - guard = AsyncGuard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=on_fail), on="prompt") - - with pytest.raises(ValidationError) as excinfo: - await guard( - custom_llm, - prompt="What kind of pet should I get?", 
- ) - assert str(excinfo.value) == structured_prompt_error - assert isinstance(guard.history.last.exception, ValidationError) - assert guard.history.last.exception == excinfo.value - - # with_instructions_validation - guard = AsyncGuard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=on_fail), on="instructions") - - with pytest.raises(ValidationError) as excinfo: - await guard( - custom_llm, - prompt="What kind of pet should I get and what should I name it?", - instructions="What kind of pet should I get?", - ) - assert str(excinfo.value) == structured_instructions_error - assert isinstance(guard.history.last.exception, ValidationError) - assert guard.history.last.exception == excinfo.value - - # with_msg_history_validation + # with_messages_validation guard = AsyncGuard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=on_fail), on="msg_history") + guard.use(TwoWords(on_fail=on_fail), on="messages") with pytest.raises(ValidationError) as excinfo: await guard( custom_llm, - msg_history=[ + messages=[ { "role": "user", "content": "What kind of pet should I get?", @@ -827,14 +794,14 @@ async def custom_llm(*args, **kwargs): def test_input_validation_mismatch_raise(): - # prompt validation, msg_history argument + # prompt validation, messages argument guard = Guard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=OnFailAction.FIX), on="prompt") + guard.use(TwoWords(on_fail=OnFailAction.FIX), on="messages") - with pytest.raises(ValueError): + with pytest.raises(ValidationError): guard( get_static_openai_create_func(), - msg_history=[ + messages=[ { "role": "user", "content": "What kind of pet should I get?", @@ -842,14 +809,14 @@ def test_input_validation_mismatch_raise(): ], ) - # instructions validation, msg_history argument + # instructions validation, messages argument guard = Guard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=OnFailAction.FIX), on="instructions") + guard.use(TwoWords(on_fail=OnFailAction.FIX), on="messages") - with pytest.raises(ValueError): + with pytest.raises(ValidationError): guard( get_static_openai_create_func(), - msg_history=[ + messages=[ { "role": "user", "content": "What kind of pet should I get?", @@ -857,12 +824,17 @@ def test_input_validation_mismatch_raise(): ], ) - # msg_history validation, prompt argument + # messages validation, prompt argument guard = Guard.from_pydantic(output_class=Pet) - guard.use(TwoWords(on_fail=OnFailAction.FIX), on="msg_history") + guard.use(TwoWords(on_fail=OnFailAction.FIX), on="messages") - with pytest.raises(ValueError): + with pytest.raises(ValidationError): guard( get_static_openai_create_func(), - prompt="What kind of pet should I get?", + messages=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], )
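# Hedged sketch (assumes the patched guardrails from this diff is installed): the
# call shape these tests exercise end to end — a single `messages=[...]` keyword
# replaces the old prompt=/instructions=/msg_history= trio on Guard and AsyncGuard.
# `mock_llm_api` mirrors the mock used in the tests above; `validated_output` is
# the standard ValidationOutcome attribute.
import json

from pydantic import BaseModel, Field

from guardrails import Guard


class Pet(BaseModel):
    pet_type: str = Field(description="Species of pet")
    name: str = Field(description="a unique pet name")


def mock_llm_api(*args, **kwargs) -> str:
    # Stand-in for a real LLM callable; guardrails forwards messages=... via kwargs.
    return json.dumps({"pet_type": "dog", "name": "Fluffy"})


guard = Guard.from_pydantic(output_class=Pet)
outcome = guard(
    mock_llm_api,
    messages=[{"role": "user", "content": "What kind of pet should I get?"}],
)
print(outcome.validated_output)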