diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 5247dda718..07b751c323 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -136,8 +136,8 @@ def copy(self, **kwargs): return new_instance - def inspect_history(self, n: int = 1): - return pretty_print_history(self.history, n) + def inspect_history(self, n: int = 1, verbose: int = 0): + return pretty_print_history(self.history, n=n, verbose=verbose) def update_global_history(self, entry): if settings.disable_history: @@ -151,4 +151,4 @@ def update_global_history(self, entry): def inspect_history(n: int = 1): """The global history shared across all LMs.""" - return pretty_print_history(GLOBAL_HISTORY, n) + return pretty_print_history(GLOBAL_HISTORY, n=n, verbose=verbose) diff --git a/dspy/primitives/program.py b/dspy/primitives/program.py index f39cc89126..20f666b6eb 100644 --- a/dspy/primitives/program.py +++ b/dspy/primitives/program.py @@ -108,8 +108,8 @@ def map_named_predictors(self, func): set_attribute_by_name(self, name, func(predictor)) return self - def inspect_history(self, n: int = 1): - return pretty_print_history(self.history, n) + def inspect_history(self, n: int = 1, verbose: int = 0): + return pretty_print_history(self.history, n=n, verbose=verbose) def batch( self, diff --git a/dspy/utils/inspect_history.py b/dspy/utils/inspect_history.py index 10333c4eb0..c555e46fb8 100644 --- a/dspy/utils/inspect_history.py +++ b/dspy/utils/inspect_history.py @@ -10,16 +10,71 @@ def _blue(text: str, end: str = "\n"): return "\x1b[34m" + str(text) + "\x1b[0m" + end -def pretty_print_history(history, n: int = 1): - """Prints the last n prompts and their completions.""" +def _yellow(text: str, end: str = "\n"): + return "\x1b[33m" + str(text) + "\x1b[0m" + end - for item in history[-n:]: + +def _magenta(text: str, end: str = "\n"): + return "\x1b[35m" + str(text) + "\x1b[0m" + end + + +def flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict: + """Recursively flattens a nested dictionary.""" + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def pretty_print_dict(data: dict): + """Pretty prints a given dictionary, flattening nested dictionaries and + coloring keys and values differently.""" + try: + flattened_data = flatten_dict(data) + + for key, value in flattened_data.items(): + if value: + print(f"{_yellow(key, end='')}: {_green(repr(value), end='')}") + except Exception as e: + print(f"An error occurred during pretty printing: {e}") + + +def pretty_print_history(history, n: int = 1, verbose: int = 0): + """Prints the last n prompts and their completions. + + Args: + history (list): A list of dictionaries containing the history of prompts and completions. + n (int): The number of most recent entries to print. Defaults to 1. + verbose (int): Verbosity level. + """ + + for index, item in enumerate(history[-n:]): messages = item["messages"] or [{"role": "user", "content": item["prompt"]}] outputs = item["outputs"] timestamp = item.get("timestamp", "Unknown time") - print("\n\n\n") - print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n") + print("\n") + print(_magenta("History entry: ", end="") + _green(f"{index + 1}")) + print(_blue(f"[{timestamp}]")) + + if verbose > 0: + usage = item["response"].get("usage") + if usage: + print(_blue(f"Usage")) + pretty_print_dict(usage.to_dict()) + print("\n") + if hasattr(item["response"].choices[0], "message"): + response_message = item["response"].choices[0].message + if hasattr(response_message, "reasoning_content"): + reasoning = response_message.reasoning_content + if reasoning: + print(_blue(f"Reasoning")) + print(_green(reasoning.strip())) + print("\n") for msg in messages: print(_red(f"{msg['role'].capitalize()} message:")) @@ -60,4 +115,5 @@ def pretty_print_history(history, n: int = 1): choices_text = f" \t (and {len(outputs) - 1} other completions)" print(_red(choices_text, end="")) - print("\n\n\n") + print("-" * 50) + print("\n\n")